1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module_dce.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
26 #include "tensorflow/compiler/xla/tests/test_utils.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/types.h"
30
31 namespace xla {
32 namespace {
33
34 class HloModuleDceTest : public HloTestBase {
35 protected:
HloModuleDceTest()36 HloModuleDceTest() {}
37
38 // Returns whether the given instruction exists in the given computation.
HasInstruction(const HloComputation & computation,const HloInstruction * instruction)39 bool HasInstruction(const HloComputation& computation,
40 const HloInstruction* instruction) {
41 return absl::c_linear_search(computation.instructions(), instruction);
42 }
43
44 // Returns whether the while instruction with name 'while_name' in
45 // 'computation' passes through its tuple element at 'tuple_index' from
46 // parameter to root instruction.
WhileBodyHasPassThroughTupleElement(const HloComputation * computation,const string & while_name,const int64 tuple_index)47 bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation,
48 const string& while_name,
49 const int64 tuple_index) {
50 for (auto* instruction : computation->instructions()) {
51 if (instruction->opcode() == HloOpcode::kWhile &&
52 instruction->name() == while_name) {
53 auto* while_body_comp = instruction->while_body();
54 auto* while_body_param = while_body_comp->parameter_instruction(0);
55 auto* while_body_root = while_body_comp->root_instruction();
56 if (while_body_root->opcode() != HloOpcode::kTuple) {
57 return false;
58 }
59 auto* operand = while_body_root->operand(tuple_index);
60 if (operand->opcode() == HloOpcode::kGetTupleElement &&
61 operand->tuple_index() == tuple_index &&
62 operand->operand(0) == while_body_param) {
63 return true;
64 }
65 return false;
66 }
67 }
68 return false;
69 }
70 };
71
72 // Tests that a while with all outputs live is unmodified.
TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
  // The while itself is the entry ROOT, so both of its tuple elements are
  // live; the pass must report no change and must not create pass-throughs.
  const char* const hlo_text = R"(
HloModule SimpleLoop
SimpleLoop.body {
  loop_var.1 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
  constant.1 = s32[] constant(1)
  add = s32[] add(get-tuple-element.1, constant.1)
  get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
  multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
  ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
}
SimpleLoop.condition {
  loop_var.2 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
  constant.2 = s32[] constant(5)
  ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
  constant.3 = s32[] constant(0)
  constant.4 = s32[3]{0} constant({0, 1, 2})
  tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
  ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
    SimpleLoop.condition, body=SimpleLoop.body
})";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  // Re-reads the entry computation on every call so checks see current state.
  auto pass_through = [&](const string& while_name, int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE module_dce;
  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_FALSE(pass_through("while", 0));
  EXPECT_FALSE(pass_through("while", 1));
}
107
// Tests that a while loop with one unused output (which is used in the while
// loop body by an instruction with side effects: rng) is unmodified.
TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
  // Element {1} is unused outside the loop, but the body contains an rng,
  // which has side effects; the pass must leave the loop unmodified.
  //
  // add.1 is declared f32[]: both of its operands are f32[] and it fills the
  // f32[] slot {1} of the body root tuple. (It was previously mis-declared as
  // s32[], producing an ill-typed module that only parsed because
  // ParseHloString does not verify shapes.)
  const char* const hlo_text = R"(
HloModule SimpleLoop
SimpleLoop.body {
  loop_var.1 = (s32[], f32[]) parameter(0)
  get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
  constant.1 = s32[] constant(1)
  add = s32[] add(get-tuple-element.1, constant.1)
  get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1
  constant.2 = f32[] constant(1.0)
  rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform
  add.1 = f32[] add(get-tuple-element.2, constant.2)
  ROOT tuple = (s32[], f32[]) tuple(add, add.1)
}
SimpleLoop.condition {
  loop_var.2 = (s32[], f32[]) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
  constant.3 = s32[] constant(5)
  ROOT less-than = pred[] compare(get-tuple-element.3, constant.3), direction=LT
}
ENTRY SimpleLoop {
  constant.4 = s32[] constant(0)
  constant.5 = f32[] constant(0.0)
  tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5)
  while = (s32[], f32[]) while(tuple.1), condition=
    SimpleLoop.condition, body=SimpleLoop.body
  ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
})";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  // Re-reads the entry computation on every call so checks see current state.
  auto pass_through = [&](const string& while_name, int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
  EXPECT_FALSE(pass_through("while", 0));
  EXPECT_FALSE(pass_through("while", 1));
}
147
148 // Tests that a while loop with one dead tuple element at {1} has its while
149 // loop body modified to make that tuple element pass-through the while body.
TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
  // Only while{0} reaches the entry ROOT, and the condition reads only {0},
  // so element {1} is dead and should be rewritten into a pass-through.
  const char* const hlo_text = R"(
HloModule SimpleLoop
SimpleLoop.body {
  loop_var.1 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
  constant.1 = s32[] constant(1)
  add = s32[] add(get-tuple-element.1, constant.1)
  get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
  multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
  ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
}
SimpleLoop.condition {
  loop_var.2 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
  constant.2 = s32[] constant(5)
  ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
  constant.3 = s32[] constant(0)
  constant.4 = s32[3]{0} constant({0, 1, 2})
  tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
  while = (s32[], s32[3]{0}) while(tuple.1), condition=
    SimpleLoop.condition, body=SimpleLoop.body
  ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
})";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  // Re-reads the entry computation on every call so checks see current state.
  auto pass_through = [&](const string& while_name, int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE module_dce;
  // Before the pass, element {1} is still computed by the body.
  EXPECT_FALSE(pass_through("while", 1));

  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Afterwards the live element {0} is untouched while the dead element {1}
  // simply flows from the body parameter to the body root.
  EXPECT_FALSE(pass_through("while", 0));
  EXPECT_TRUE(pass_through("while", 1));
}
189
// Tests that a tuple element {1} used by the condition computation (which
// appears dead in while.body{1} and at while.result{1}) propagates liveness of
// this tuple element to while.body{1} and while.result{1}.
TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
  // Outside the body, element {1} is read only by the loop condition. That use
  // keeps it live, so the pass must not turn it into a pass-through.
  const char* const hlo_text = R"(
HloModule SimpleLoop
SimpleLoop.body {
  loop_var.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
  constant.1 = s32[] constant(1)
  add = s32[] add(get-tuple-element.1, constant.1)
  get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1
  multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2)
  ROOT tuple = (s32[], s32[]) tuple(add, multiply)
}
SimpleLoop.condition {
  loop_var.2 = (s32[], s32[]) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
  constant.2 = s32[] constant(5)
  ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
  constant.3 = s32[] constant(0)
  constant.4 = s32[] constant(0)
  tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4)
  while = (s32[], s32[]) while(tuple.1), condition=
    SimpleLoop.condition, body=SimpleLoop.body
  ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
})";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  // Re-reads the entry computation on every call so checks see current state.
  auto pass_through = [&](const string& while_name, int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  // While tuple element {1} should not be pass-through before ModuleDCE.
  EXPECT_FALSE(pass_through("while", 1));
  EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
  EXPECT_FALSE(pass_through("while", 0));
  // While tuple element {1} should still NOT be pass-through after ModuleDCE,
  // because the condition computation reads it.
  EXPECT_FALSE(pass_through("while", 1));
}
232
233 // Tests that HloModuleDCE can remove a dead tuple element at index {1} between
234 // two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
  // while.2 consumes while.1{0}; only while.2{0} reaches the entry ROOT. Both
  // loops' element {1} is therefore dead and should become a pass-through.
  const char* const hlo_text = R"(
HloModule SimpleLoop
SimpleLoop.body0 {
  loop_var.1 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
  constant.1 = s32[] constant(1)
  add = s32[] add(get-tuple-element.1, constant.1)
  get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
  multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
  ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
}
SimpleLoop.condition0 {
  loop_var.2 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
  constant.2 = s32[] constant(5)
  ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
SimpleLoop.body1 {
  loop_var.3 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
  constant.3 = s32[] constant(1)
  add.1 = s32[] add(get-tuple-element.4, constant.3)
  get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
  multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5)
  ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1)
}
SimpleLoop.condition1 {
  loop_var.4 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
  constant.4 = s32[] constant(5)
  ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
}
ENTRY SimpleLoop {
  constant.5 = s32[] constant(0)
  constant.6 = s32[3]{0} constant({0, 1, 2})
  tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6)
  while.1 = (s32[], s32[3]{0}) while(tuple.2), condition=
    SimpleLoop.condition0, body=SimpleLoop.body0
  get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0
  tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6)
  while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
    SimpleLoop.condition1, body=SimpleLoop.body1
  ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
})";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  // Re-reads the entry computation on every call so checks see current state.
  auto pass_through = [&](const string& while_name, int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE module_dce;
  // Initially neither loop forwards element {1} unchanged.
  EXPECT_FALSE(pass_through("while.1", 1));
  EXPECT_FALSE(pass_through("while.2", 1));

  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // After the pass each loop keeps computing its live element {0} and merely
  // forwards the dead element {1}.
  EXPECT_FALSE(pass_through("while.1", 0));
  EXPECT_TRUE(pass_through("while.1", 1));
  EXPECT_FALSE(pass_through("while.2", 0));
  EXPECT_TRUE(pass_through("while.2", 1));
}
300
301 // Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
302 // while.2{1}, between two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
  // while.1 carries its counter in slot {1}, while.2 in slot {0}. Only the
  // counters are live, so while.1{0} and while.2{1} are dead.
  const char* const hlo_text = R"(
HloModule SimpleLoop
SimpleLoop.body0 {
  loop_var.1 = (s32[3]{0}, s32[]) parameter(0)
  get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1
  constant.1 = s32[] constant(1)
  add = s32[] add(get-tuple-element.1, constant.1)
  get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0
  multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
  ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add)
}
SimpleLoop.condition0 {
  loop_var.2 = (s32[3]{0}, s32[]) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
  constant.2 = s32[] constant(5)
  ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
SimpleLoop.body1 {
  loop_var.3 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
  constant.3 = s32[] constant(1)
  add.1 = s32[] add(get-tuple-element.4, constant.3)
  get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
  multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5)
  ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1)
}
SimpleLoop.condition1 {
  loop_var.4 = (s32[], s32[3]{0}) parameter(0)
  get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
  constant.4 = s32[] constant(5)
  ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
}
ENTRY SimpleLoop {
  constant.5 = s32[] constant(0)
  constant.6 = s32[3]{0} constant({0, 1, 2})
  tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5)
  while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition=
    SimpleLoop.condition0, body=SimpleLoop.body0
  get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1
  tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6)
  while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
    SimpleLoop.condition1, body=SimpleLoop.body1
  ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
})";
  auto module = ParseHloString(hlo_text).ValueOrDie();
  // Re-reads the entry computation on every call so checks see current state.
  auto pass_through = [&](const string& while_name, int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru.
  EXPECT_FALSE(pass_through("while.1", 0));
  EXPECT_FALSE(pass_through("while.2", 1));
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
  // After HloModuleDCE the dead elements while.1{0} and while.2{1} should be
  // pass-thru elements, while the live elements while.1{1} and while.2{0}
  // should not.
  EXPECT_FALSE(pass_through("while.1", 1));
  EXPECT_TRUE(pass_through("while.1", 0));
  EXPECT_FALSE(pass_through("while.2", 0));
  EXPECT_TRUE(pass_through("while.2", 1));
}
367
368 // Tests that a while whose body has outfeed operations is not DCE-ed.
TEST_F(HloModuleDceTest, WhileWithOutfeed) {
  // The while result is never used by the entry ROOT, but its body contains an
  // outfeed; the pass must report no change and leave {0} un-forwarded.
  const char* const hlo_text = R"(
HloModule OutfeedLoop
WhileBody {
  body_param = (s32[]) parameter(0)
  token0 = token[] after-all()
  constant.2 = s32[] constant(2)
  outfeed_tuple = (s32[]) outfeed(constant.2, token0)
  get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
  constant.1 = s32[] constant(1)
  add = s32[] add(get-tuple-element.1, constant.1)
  ROOT tuple = (s32[]) tuple(add)
}
WhileCondition {
  cond_param = (s32[]) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
  constant.2 = s32[] constant(10)
  ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
  constant.3 = s32[] constant(0)
  tuple.1 = (s32[]) tuple(constant.3)
  while = (s32[]) while(tuple.1), condition=WhileCondition,
    body=WhileBody
  ROOT rtuple = () tuple()
})";
  auto module = ParseHloString(hlo_text).ValueOrDie();

  HloModuleDCE module_dce;
  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while", 0));
}
402
403 // Tests that if a loop variable is not referenced outside of a kWhile, the loop
404 // variable changes are not elided within the loop body, if the condition
405 // computation uses them.
TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
  // Only while{1} reaches the entry ROOT, but the condition reads {0}, so the
  // increment of {0} inside the body must survive the pass.
  const char* const hlo_text = R"(
HloModule InfiniteLoop
WhileBody {
  body_param = (s32[], s32[]) parameter(0)
  get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
  get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1
  constant.1 = s32[] constant(1)
  add = s32[] add(get-tuple-element.1, constant.1)
  ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2)
}
WhileCondition {
  cond_param = (s32[], s32[]) parameter(0)
  get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
  constant.2 = s32[] constant(10)
  ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
  p0 = (s32[]) parameter(0)
  get-tuple-element.5 = s32[] get-tuple-element(p0), index=0
  constant.3 = s32[] constant(0)
  tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5)
  while = (s32[], s32[]) while(tuple.1), condition=WhileCondition,
    body=WhileBody
  ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1
})";
  auto module = ParseHloString(hlo_text).ValueOrDie();

  HloModuleDCE module_dce;
  const bool changed = module_dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                                   "while", 0));
}
439
440 } // namespace
441 } // namespace xla
442