1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module_dce.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
26 #include "tensorflow/compiler/xla/tests/test_utils.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace xla {
32 namespace {
33 
34 class HloModuleDceTest : public HloTestBase {
35  protected:
HloModuleDceTest()36   HloModuleDceTest() {}
37 
38   // Returns whether the given instruction exists in the given computation.
HasInstruction(const HloComputation & computation,const HloInstruction * instruction)39   bool HasInstruction(const HloComputation& computation,
40                       const HloInstruction* instruction) {
41     return absl::c_linear_search(computation.instructions(), instruction);
42   }
43 
44   // Returns whether the while instruction with name 'while_name' in
45   // 'computation' passes through its tuple element at 'tuple_index' from
46   // parameter to root instruction.
WhileBodyHasPassThroughTupleElement(const HloComputation * computation,const string & while_name,const int64 tuple_index)47   bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation,
48                                            const string& while_name,
49                                            const int64 tuple_index) {
50     for (auto* instruction : computation->instructions()) {
51       if (instruction->opcode() == HloOpcode::kWhile &&
52           instruction->name() == while_name) {
53         auto* while_body_comp = instruction->while_body();
54         auto* while_body_param = while_body_comp->parameter_instruction(0);
55         auto* while_body_root = while_body_comp->root_instruction();
56         if (while_body_root->opcode() != HloOpcode::kTuple) {
57           return false;
58         }
59         auto* operand = while_body_root->operand(tuple_index);
60         if (operand->opcode() == HloOpcode::kGetTupleElement &&
61             operand->tuple_index() == tuple_index &&
62             operand->operand(0) == while_body_param) {
63           return true;
64         }
65         return false;
66       }
67     }
68     return false;
69   }
70 };
71 
72 // Tests that a while with all outputs live is unmodified.
TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
  // Both while outputs reach the entry root, so the pass has nothing to do.
  constexpr char kHloText[] = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
  })";
  auto module = ParseHloString(kHloText).ValueOrDie();

  // Queries the module's current entry computation on every call.
  auto is_pass_through = [&](const string& while_name, int64 index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, index);
  };

  HloModuleDCE module_dce;
  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_FALSE(is_pass_through("while", 0));
  EXPECT_FALSE(is_pass_through("while", 1));
}
107 
// Tests that a while loop with one unused output (which is used in the while
// loop body by an instruction with side-effects: rng) is unmodified.
TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
  // Tuple element {1} is f32[] throughout: the loop parameter, the rng input,
  // 'add.1' (an f32[] + f32[] add) and the root tuple all use f32[]. The rng
  // in the body consumes element {1}, so that element must stay live even
  // though the entry computation only reads element {0}.
  auto module = ParseHloString(R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], f32[]) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1
    constant.2 = f32[] constant(1.0)
    rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform
    add.1 = f32[] add(get-tuple-element.2, constant.2)
    ROOT tuple = (s32[], f32[]) tuple(add, add.1)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], f32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.3 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.3), direction=LT
  }
  ENTRY SimpleLoop {
    constant.4 = s32[] constant(0)
    constant.5 = f32[] constant(0.0)
    tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5)
    while = (s32[], f32[]) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
  })")
                    .ValueOrDie();

  HloModuleDCE dce;
  // The module must be left untouched: neither element becomes pass-through.
  EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while", 0));
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while", 1));
}
147 
148 // Tests that a while loop with one dead tuple element at {1} has its while
149 // loop body modified to make that tuple element pass-through the while body.
TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
  // Only element {0} of the while result is read by the entry root.
  constexpr char kHloText[] = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    while = (s32[], s32[3]{0}) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
  })";
  auto module = ParseHloString(kHloText).ValueOrDie();

  // Queries the module's current entry computation on every call.
  auto is_pass_through = [&](const string& while_name, int64 index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, index);
  };

  HloModuleDCE module_dce;
  // Before the pass, element {1} is recomputed (multiply) in the body.
  EXPECT_FALSE(is_pass_through("while", 1));
  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_FALSE(is_pass_through("while", 0));
  // After the pass, the dead element {1} is forwarded unchanged by the body.
  EXPECT_TRUE(is_pass_through("while", 1));
}
189 
// Tests that a tuple element {1} used by the condition computation (which
// appears dead in while.body{1} and at while.result{1}) propagates liveness of
// this tuple element to while.body{1} and to while.result{1}.
TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
  // The condition reads element {1}, so that element is live even though the
  // entry computation only reads element {0} of the while result.
  auto module = ParseHloString(R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[]) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1
    multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[]) tuple(add, multiply)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    constant.4 = s32[] constant(0)
    tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4)
    while = (s32[], s32[]) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
  })")
                    .ValueOrDie();

  HloModuleDCE dce;
  // While tuple element {1} should not be pass-through before ModuleDCE.
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while", 1));
  EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while", 0));
  // While tuple element {1} should still not be pass-through after ModuleDCE,
  // because the condition computation keeps it live.
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while", 1));
}
232 
233 // Tests that HloModuleDCE can remove a dead tuple element at index {1} between
234 // two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
  // while.1 feeds element {0} into while.2, and only element {0} of while.2
  // reaches the entry root, so element {1} is dead in both loops.
  constexpr char kHloText[] = R"(
  HloModule SimpleLoop
  SimpleLoop.body0 {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  SimpleLoop.condition0 {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  SimpleLoop.body1 {
    loop_var.3 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
    constant.3 = s32[] constant(1)
    add.1 = s32[] add(get-tuple-element.4, constant.3)
    get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
    multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5)
    ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1)
  }
  SimpleLoop.condition1 {
    loop_var.4 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
    constant.4 = s32[] constant(5)
    ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
  }
  ENTRY SimpleLoop {
    constant.5 = s32[] constant(0)
    constant.6 = s32[3]{0} constant({0, 1, 2})
    tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6)
    while.1 = (s32[], s32[3]{0}) while(tuple.2), condition=
      SimpleLoop.condition0, body=SimpleLoop.body0
    get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0
    tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6)
    while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
      SimpleLoop.condition1, body=SimpleLoop.body1
    ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
  })";
  auto module = ParseHloString(kHloText).ValueOrDie();

  // Queries the module's current entry computation on every call.
  auto is_pass_through = [&](const string& while_name, int64 index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, index);
  };

  HloModuleDCE module_dce;
  // Initially neither loop forwards element {1} untouched.
  EXPECT_FALSE(is_pass_through("while.1", 1));
  EXPECT_FALSE(is_pass_through("while.2", 1));
  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  // Afterwards both loops forward the dead element {1}; element {0} is still
  // computed in each body.
  EXPECT_FALSE(is_pass_through("while.1", 0));
  EXPECT_TRUE(is_pass_through("while.1", 1));
  EXPECT_FALSE(is_pass_through("while.2", 0));
  EXPECT_TRUE(is_pass_through("while.2", 1));
}
300 
301 // Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
302 // while.2{1}, between two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
  // while.1 carries (s32[3], s32[]) and only its element {1} feeds while.2,
  // whose element {0} is the entry root. So while.1{0} and while.2{1} are dead.
  auto module = ParseHloString(R"(
  HloModule SimpleLoop
  SimpleLoop.body0 {
    loop_var.1 = (s32[3]{0}, s32[]) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add)
  }
  SimpleLoop.condition0 {
    loop_var.2 = (s32[3]{0}, s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  SimpleLoop.body1 {
    loop_var.3 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
    constant.3 = s32[] constant(1)
    add.1 = s32[] add(get-tuple-element.4, constant.3)
    get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
    multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5)
    ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1)
  }
  SimpleLoop.condition1 {
    loop_var.4 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
    constant.4 = s32[] constant(5)
    ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
  }
  ENTRY SimpleLoop {
    constant.5 = s32[] constant(0)
    constant.6 = s32[3]{0} constant({0, 1, 2})
    tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5)
    while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition=
      SimpleLoop.condition0, body=SimpleLoop.body0
    get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1
    tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6)
    while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
      SimpleLoop.condition1, body=SimpleLoop.body1
    ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
  })")
                    .ValueOrDie();

  HloModuleDCE dce;
  // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru.
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while.1", 0));
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while.2", 1));
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
  // After HloModuleDCE while.1{0} and while.2{1} should be pass-thru elements,
  // while the live elements while.1{1} and while.2{0} should not.
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while.1", 1));
  EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                  "while.1", 0));
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while.2", 0));
  EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                  "while.2", 1));
}
367 
368 // Tests that a while whose body has outfeed operations is not DCE-ed.
TEST_F(HloModuleDceTest, WhileWithOutfeed) {
  // The while result is unused by the entry root, but the body performs an
  // outfeed, so the loop must survive the pass.
  constexpr char kHloText[] = R"(
  HloModule OutfeedLoop
  WhileBody {
    body_param = (s32[]) parameter(0)
    token0 = token[] after-all()
    constant.2 = s32[] constant(2)
    outfeed_tuple = (s32[]) outfeed(constant.2, token0)
    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    ROOT tuple = (s32[]) tuple(add)
  }
  WhileCondition {
    cond_param = (s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
    constant.2 = s32[] constant(10)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    tuple.1 = (s32[]) tuple(constant.3)
    while = (s32[]) while(tuple.1), condition=WhileCondition,
      body=WhileBody
    ROOT rtuple = () tuple()
  })";
  auto module = ParseHloString(kHloText).ValueOrDie();

  HloModuleDCE module_dce;
  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  const bool counter_passes_through = WhileBodyHasPassThroughTupleElement(
      module->entry_computation(), "while", 0);
  EXPECT_FALSE(counter_passes_through);
}
402 
403 // Tests that if a loop variable is not referenced outside of a kWhile, the loop
404 // variable changes are not elided within the loop body, if the condition
405 // computation uses them.
TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
  // The entry only reads element {1}, but the condition reads the counter in
  // element {0}, so the counter increment in the body must be kept.
  constexpr char kHloText[] = R"(
  HloModule InfiniteLoop
  WhileBody {
    body_param = (s32[], s32[]) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
    get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2)
  }
  WhileCondition {
    cond_param = (s32[], s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
    constant.2 = s32[] constant(10)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY SimpleLoop {
    p0 = (s32[]) parameter(0)
    get-tuple-element.5 = s32[] get-tuple-element(p0), index=0
    constant.3 = s32[] constant(0)
    tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5)
    while = (s32[], s32[]) while(tuple.1), condition=WhileCondition,
      body=WhileBody
    ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1
  })";
  auto module = ParseHloString(kHloText).ValueOrDie();

  HloModuleDCE module_dce;
  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  const bool counter_passes_through = WhileBodyHasPassThroughTupleElement(
      module->entry_computation(), "while", 0);
  EXPECT_FALSE(counter_passes_through);
}
439 
440 }  // namespace
441 }  // namespace xla
442