1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/hlo_dce.h"

#include <memory>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
35
36 namespace xla {
37 namespace {
38
39 class HloDceTest : public HloTestBase {
40 protected:
HloDceTest()41 HloDceTest() {}
42
43 // Returns whether the given instruction exists in the given computation.
HasInstruction(const HloComputation & computation,const HloInstruction * instruction)44 bool HasInstruction(const HloComputation& computation,
45 const HloInstruction* instruction) {
46 return absl::c_linear_search(computation.instructions(), instruction);
47 }
48 };
49
TEST_F(HloDceTest, NoDeadCode) {
  // Every instruction feeds the root, so DCE must report "no change" and keep
  // all three instructions.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  // Added last, so the add becomes the root of the computation.
  builder.AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(3, entry->instruction_count());

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_EQ(3, entry->instruction_count());
}
70
TEST_F(HloDceTest, InstructionsWithSideEffect) {
  // A Send has a side effect, so DCE must not remove it (or the values it
  // consumes) even though nothing uses its result. The empty tuple, added
  // last, is the root.
  HloComputation::Builder builder(TestName());
  HloInstruction* payload = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* send_token =
      builder.AddInstruction(HloInstruction::CreateToken());
  builder.AddInstruction(
      HloInstruction::CreateSend(payload, send_token, /*channel_id=*/0));
  builder.AddInstruction(HloInstruction::CreateTuple({}));

  auto module = CreateNewUnverifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(4, entry->instruction_count());

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_EQ(4, entry->instruction_count());
}
91
TEST_F(HloDceTest, DeadParameters) {
  // Verify that dead parameters are not removed, but uses of the dead
  // parameters are.
  auto builder = HloComputation::Builder(TestName());
  auto live_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "live_param"));
  auto dead_param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {}), "dead_param1"));
  auto dead_param2 = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(F32, {}), "dead_param2"));

  // This is a dead negate instruction.
  auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary(
      dead_param1->shape(), HloOpcode::kNegate, dead_param1));

  // This negate is not dead because it is the root.
  builder.AddInstruction(HloInstruction::CreateUnary(
      live_param->shape(), HloOpcode::kNegate, live_param));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(5, computation->instruction_count());
  EXPECT_EQ(1, dead_param1->user_count());
  EXPECT_TRUE(HasInstruction(*computation, dead_negate));

  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // Exactly the dead negate is removed; all parameters survive.
  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_FALSE(HasInstruction(*computation, dead_negate));
  EXPECT_TRUE(HasInstruction(*computation, live_param));
  EXPECT_TRUE(HasInstruction(*computation, dead_param2));
  // ASSERT (not EXPECT) so that dead_param1 is never dereferenced below if a
  // regression in DCE deleted it.
  ASSERT_TRUE(HasInstruction(*computation, dead_param1));
  EXPECT_EQ(0, dead_param1->user_count());
}
123
TEST_F(HloDceTest, ControlDependencies) {
  // Instructions that take part in a control dependency must survive DCE even
  // without data users, while identical instructions lacking such an edge are
  // removed.
  HloComputation::Builder builder(TestName());
  HloInstruction* const42 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const123 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  const Shape& scalar = const42->shape();

  // A dead negate/add pair with no control edges.
  HloInstruction* negate_plain = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, const42));
  HloInstruction* add_plain = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, const42, const123));

  // The same pair again; a control edge is attached to these further down.
  HloInstruction* negate_ctrl = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, const42));
  HloInstruction* add_ctrl = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, const42, const123));

  // Added last, so this add is the root and the four instructions above have
  // no data users.
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, const42, const123));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // negate_ctrl must execute before add_ctrl.
  TF_ASSERT_OK(negate_ctrl->AddControlDependencyTo(add_ctrl));

  EXPECT_EQ(7, entry->instruction_count());
  for (const HloInstruction* instr :
       {negate_plain, add_plain, negate_ctrl, add_ctrl}) {
    EXPECT_TRUE(HasInstruction(*entry, instr));
  }

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_EQ(5, entry->instruction_count());
  EXPECT_FALSE(HasInstruction(*entry, negate_plain));
  EXPECT_FALSE(HasInstruction(*entry, add_plain));
  EXPECT_TRUE(HasInstruction(*entry, negate_ctrl));
  EXPECT_TRUE(HasInstruction(*entry, add_ctrl));
}
173
174 // Tests that a dead call instruction is removed.
TEST_F(HloDceTest, DeadInstructionWithCalledComputation) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Callee: negates its single parameter.
  HloComputation::Builder callee_builder(TestName() + "-callee");
  HloInstruction* callee_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  callee_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_f32, HloOpcode::kNegate, callee_param));
  HloComputation* callee =
      module->AddEmbeddedComputation(callee_builder.Build());

  // Entry: a call whose result is unused, then a negate that becomes the
  // root because it is added last.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  HloInstruction* dead_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_f32, {entry_param}, callee));
  entry_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_f32, HloOpcode::kNegate, entry_param));
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_EQ(2, entry_param->user_count());
  EXPECT_EQ(0, dead_call->user_count());
  EXPECT_TRUE(HasInstruction(*entry, dead_call));

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // The call is gone, so the parameter is left with only the negate as user.
  EXPECT_EQ(2, entry->instruction_count());
  EXPECT_EQ(1, entry_param->user_count());
  EXPECT_FALSE(HasInstruction(*entry, dead_call));
}
212
// Tests that a while instruction whose body contains an infeed (an effectful
// instruction) is not removed, even if its user count is 0.
TEST_F(HloDceTest, CalledComputationWithSideEffect) {
  auto module = CreateNewUnverifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // While condition: compares the loop state against the constant 42.
  HloComputation::Builder cond_builder(TestName() + "-cond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "cond_param"));
  HloInstruction* limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  cond_builder.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), cond_param, limit,
      ComparisonDirection::kLt));
  HloComputation* cond = module->AddEmbeddedComputation(cond_builder.Build());

  // While body: contains an infeed whose result feeds an add with the loop
  // state. The infeed is the side effect that must keep the while alive.
  HloComputation::Builder body_builder(TestName() + "-body");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  HloInstruction* infeed_token =
      body_builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* infeed = body_builder.AddInstruction(
      HloInstruction::CreateInfeed(scalar_f32, infeed_token, ""));
  body_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kAdd, body_param, infeed));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Entry: an unused while, followed by a negate that becomes the root.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  HloInstruction* live_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_f32, cond, body, entry_param));
  entry_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_f32, HloOpcode::kNegate, entry_param));
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  // The same state must hold before and after DCE: the while has no users
  // yet is still present.
  const auto expect_while_kept = [&] {
    EXPECT_EQ(3, entry->instruction_count());
    EXPECT_EQ(2, entry_param->user_count());
    EXPECT_EQ(0, live_while->user_count());
    EXPECT_TRUE(HasInstruction(*entry, live_while));
  };

  expect_while_kept();
  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  expect_while_kept();
}
269
270 // Tests that a nested call instruction with a side effect is not removed.
TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
  auto module = CreateNewUnverifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Innermost computation: outfeeds its parameter, which is the side effect.
  HloComputation::Builder nested_callee_builder(TestName() + "-nested_callee");
  HloInstruction* outfeed_operand = nested_callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  HloInstruction* outfeed_token =
      nested_callee_builder.AddInstruction(HloInstruction::CreateToken());
  nested_callee_builder.AddInstruction(HloInstruction::CreateOutfeed(
      scalar_f32, outfeed_operand, outfeed_token, ""));
  HloComputation* nested_callee =
      module->AddEmbeddedComputation(nested_callee_builder.Build());

  // Middle computation: forwards its parameter into a call of the innermost
  // computation.
  HloComputation::Builder callee_builder(TestName() + "-callee");
  HloInstruction* callee_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  callee_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_f32, {callee_param}, nested_callee));
  HloComputation* callee =
      module->AddEmbeddedComputation(callee_builder.Build());

  // Entry: an unused call of the middle computation, followed by a negate
  // that becomes the root.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  HloInstruction* live_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_f32, {entry_param}, callee));
  entry_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_f32, HloOpcode::kNegate, entry_param));
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  // The same state must hold before and after DCE: the call has no users
  // yet is still present.
  const auto expect_call_kept = [&] {
    EXPECT_EQ(3, entry->instruction_count());
    EXPECT_EQ(2, entry_param->user_count());
    EXPECT_EQ(0, live_call->user_count());
    EXPECT_TRUE(HasInstruction(*entry, live_call));
  };

  expect_call_kept();
  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  expect_call_kept();
}
323
TEST_F(HloDceTest, RemoveDeadSubcomputation) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Reduction computation: (f32, f32) -> f32 addition.
  HloComputation::Builder subcomp_builder("reduction_subcomp");
  {
    auto* param0 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/0, scalar_f32, "param0"));
    auto* param1 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/1, scalar_f32, "param1"));
    subcomp_builder.AddInstruction(HloInstruction::CreateBinary(
        scalar_f32, HloOpcode::kAdd, param0, param1));
  }
  auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build());

  // Create the reduce operands as separate statements: function-argument
  // evaluation order is unspecified, so nesting these AddInstruction calls in
  // the CreateReduce argument list would make instruction order
  // compiler-dependent.
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0"));
  HloInstruction* init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));

  // Create a dead reduce instruction. Reducing the only dimension of an
  // F32[100] operand yields a scalar, so the result shape is F32[].
  builder.AddInstruction(HloInstruction::CreateReduce(
      scalar_f32, operand, init_value,
      /*dimensions_to_reduce=*/{0}, reduce_subcomp));

  // Add another instruction as the root of the computation.
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));

  module->AddEntryComputation(builder.Build());
  // Entry computation plus the reduction subcomputation.
  EXPECT_EQ(2u, module->MakeComputationPostOrder().size());

  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // We should have DCE'ed the reduction computation along with the reduction
  // instruction.
  EXPECT_EQ(1u, module->MakeComputationPostOrder().size());
}
364
TEST_F(HloDceTest, KeepUsedSubcomputation) {
  // With correctly shaped reduces the module is well-formed, so it can be
  // verified like the other tests that build valid HLO.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  const Shape vector_f32 = ShapeUtil::MakeShape(F32, {100});

  // Reduction computation: (f32, f32) -> f32 addition.
  HloComputation::Builder subcomp_builder("reduction_subcomp");
  {
    auto* param0 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/0, scalar_f32, "param0"));
    auto* param1 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/1, scalar_f32, "param1"));
    subcomp_builder.AddInstruction(HloInstruction::CreateBinary(
        scalar_f32, HloOpcode::kAdd, param0, param1));
  }
  auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build());

  // Operands are created as separate statements so that instruction order
  // does not depend on the unspecified evaluation order of function
  // arguments.
  HloInstruction* dead_operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, vector_f32, "param0"));
  HloInstruction* dead_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));

  // Create a dead reduce instruction. Reducing the only dimension of an
  // F32[100] operand yields a scalar, so the result shape is F32[].
  builder.AddInstruction(HloInstruction::CreateReduce(
      scalar_f32, dead_operand, dead_init,
      /*dimensions_to_reduce=*/{0}, reduce_subcomp));

  HloInstruction* live_operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, vector_f32, "param1"));
  HloInstruction* live_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));

  // Add another instruction as the root of the computation that also uses
  // reduce_subcomp.
  builder.AddInstruction(HloInstruction::CreateReduce(
      scalar_f32, live_operand, live_init,
      /*dimensions_to_reduce=*/{0}, reduce_subcomp));

  module->AddEntryComputation(builder.Build());
  // Entry computation plus the reduction subcomputation.
  EXPECT_EQ(2u, module->MakeComputationPostOrder().size());

  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // We shouldn't have DCE'ed reduce_subcomp, even though we removed one of
  // its users.
  EXPECT_EQ(2u, module->MakeComputationPostOrder().size());
}
411
412 } // namespace
413 } // namespace xla
414