1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dce.h"
17 
18 #include <memory>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/xla/layout_util.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_utils.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 namespace {
38 
39 class HloDceTest : public HloTestBase {
40  protected:
HloDceTest()41   HloDceTest() {}
42 
43   // Returns whether the given instruction exists in the given computation.
HasInstruction(const HloComputation & computation,const HloInstruction * instruction)44   bool HasInstruction(const HloComputation& computation,
45                       const HloInstruction* instruction) {
46     return absl::c_linear_search(computation.instructions(), instruction);
47   }
48 };
49 
// Checks that HloDCE reports no change and removes nothing when every
// instruction in the computation contributes to the root.
TEST_F(HloDceTest, NoDeadCode) {
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  // The add is added last, so it is the root and consumes both constants.
  builder.AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  const int kExpectedInstructionCount = 3;
  EXPECT_EQ(kExpectedInstructionCount, entry->instruction_count());

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_EQ(kExpectedInstructionCount, entry->instruction_count());
}
70 
// Checks that a side-effecting instruction (a Send) survives DCE even though
// nothing uses its result.
TEST_F(HloDceTest, InstructionsWithSideEffect) {
  auto module = CreateNewUnverifiedModule();

  HloComputation::Builder builder(TestName());
  HloInstruction* payload = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  builder.AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  // The empty tuple is added last and becomes the root, leaving the Send with
  // no users.
  builder.AddInstruction(HloInstruction::CreateTuple({}));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(4, entry->instruction_count());

  HloDCE dce;
  EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
  EXPECT_EQ(4, entry->instruction_count());
}
91 
// Checks that parameters are never removed, even when they have no users,
// while dead non-parameter instructions that use them are removed.
TEST_F(HloDceTest, DeadParameters) {
  auto builder = HloComputation::Builder(TestName());
  auto live_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "live_param"));
  auto dead_param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {}), "dead_param1"));
  auto dead_param2 = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(F32, {}), "dead_param2"));

  // This negate is dead: nothing uses it and it is not the root.
  auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary(
      dead_param1->shape(), HloOpcode::kNegate, dead_param1));

  // This negate is not dead because it is the root (the last instruction
  // added to the builder).
  auto live_negate = builder.AddInstruction(HloInstruction::CreateUnary(
      live_param->shape(), HloOpcode::kNegate, live_param));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(5, computation->instruction_count());
  EXPECT_EQ(1, dead_param1->user_count());
  EXPECT_EQ(live_negate, computation->root_instruction());
  EXPECT_TRUE(HasInstruction(*computation, dead_negate));

  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // Exactly the dead negate was removed; all three parameters and the root
  // are still present.
  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_EQ(0, dead_param1->user_count());
  EXPECT_FALSE(HasInstruction(*computation, dead_negate));
  EXPECT_TRUE(HasInstruction(*computation, live_param));
  EXPECT_TRUE(HasInstruction(*computation, dead_param1));
  EXPECT_TRUE(HasInstruction(*computation, dead_param2));
  EXPECT_EQ(live_negate, computation->root_instruction());
}
123 
// Checks that dead instructions are removed unless they take part in a control
// dependency, in which case they are kept.
TEST_F(HloDceTest, ControlDependencies) {
  HloComputation::Builder builder(TestName());
  HloInstruction* c42 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* c123 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  const Shape& scalar = c42->shape();

  // Two dead instructions without control dependencies: DCE should drop them.
  HloInstruction* dead_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, c42));
  HloInstruction* dead_add = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, c42, c123));

  // Identical copies of the two instructions above; a control edge is added
  // between these two below.
  HloInstruction* ctrl_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, c42));
  HloInstruction* ctrl_add = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, c42, c123));

  // Added last so that it becomes the root and none of the instructions above
  // has a data user.
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, c42, c123));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(ctrl_negate->AddControlDependencyTo(ctrl_add));

  EXPECT_EQ(7, computation->instruction_count());
  for (const HloInstruction* instruction :
       {dead_negate, dead_add, ctrl_negate, ctrl_add}) {
    EXPECT_TRUE(HasInstruction(*computation, instruction));
  }

  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // Only the pair without a control dependency was removed.
  EXPECT_EQ(5, computation->instruction_count());
  EXPECT_FALSE(HasInstruction(*computation, dead_negate));
  EXPECT_FALSE(HasInstruction(*computation, dead_add));
  EXPECT_TRUE(HasInstruction(*computation, ctrl_negate));
  EXPECT_TRUE(HasInstruction(*computation, ctrl_add));
}
173 
174 // Tests that a dead call instruction is removed.
TEST_F(HloDceTest,DeadInstructionWithCalledComputation)175 TEST_F(HloDceTest, DeadInstructionWithCalledComputation) {
176   auto module = CreateNewVerifiedModule();
177   Shape shape = ShapeUtil::MakeShape(F32, {});
178 
179   // Called computation for the call instruction.
180   auto callee_builder = HloComputation::Builder(TestName() + "-callee");
181   {
182     auto param = callee_builder.AddInstruction(
183         HloInstruction::CreateParameter(0, shape, "param"));
184     callee_builder.AddInstruction(
185         HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
186   }
187   auto called_computation =
188       module->AddEmbeddedComputation(callee_builder.Build());
189 
190   // Entry computation with a call instruction.
191   auto builder = HloComputation::Builder(TestName());
192   auto param = builder.AddInstruction(
193       HloInstruction::CreateParameter(0, shape, "param"));
194   auto dead_call = builder.AddInstruction(
195       HloInstruction::CreateCall(shape, {param}, called_computation));
196   builder.AddInstruction(
197       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
198   auto computation = module->AddEntryComputation(builder.Build());
199 
200   EXPECT_EQ(3, computation->instruction_count());
201   EXPECT_EQ(2, param->user_count());
202   EXPECT_EQ(0, dead_call->user_count());
203   EXPECT_TRUE(HasInstruction(*computation, dead_call));
204 
205   HloDCE dce;
206   EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
207 
208   EXPECT_EQ(2, computation->instruction_count());
209   EXPECT_EQ(1, param->user_count());
210   EXPECT_FALSE(HasInstruction(*computation, dead_call));
211 }
212 
213 // Tests that a while instruction with an infeed (effectul instruction) in its
214 // body is not removed, even its user count is 0.
TEST_F(HloDceTest,CalledComputationWithSideEffect)215 TEST_F(HloDceTest, CalledComputationWithSideEffect) {
216   auto module = CreateNewUnverifiedModule();
217   Shape shape = ShapeUtil::MakeShape(F32, {});
218 
219   // Condition computation of a while instruction.
220   auto cond_builder = HloComputation::Builder(TestName() + "-cond");
221   {
222     auto param = cond_builder.AddInstruction(
223         HloInstruction::CreateParameter(0, shape, "cond_param"));
224     auto constant = cond_builder.AddInstruction(
225         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
226     cond_builder.AddInstruction(
227         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
228                                       constant, ComparisonDirection::kLt));
229   }
230   auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
231 
232   // Body computation of a while instruction.
233   auto body_builder = HloComputation::Builder(TestName() + "-body");
234   {
235     auto param = body_builder.AddInstruction(
236         HloInstruction::CreateParameter(0, shape, "param"));
237     auto token = body_builder.AddInstruction(HloInstruction::CreateToken());
238     auto infeed = body_builder.AddInstruction(
239         HloInstruction::CreateInfeed(shape, token, ""));
240     body_builder.AddInstruction(
241         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed));
242   }
243   auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
244 
245   // Entry computation with a while instruction and a negate on the parameter.
246   auto builder = HloComputation::Builder(TestName());
247   auto param = builder.AddInstruction(
248       HloInstruction::CreateParameter(0, shape, "param"));
249   auto live_while = builder.AddInstruction(HloInstruction::CreateWhile(
250       shape, cond_computation, body_computation, param));
251   builder.AddInstruction(
252       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
253   auto computation = module->AddEntryComputation(builder.Build());
254 
255   // Check the while instruction is not removed even if its user count is 0.
256   EXPECT_EQ(3, computation->instruction_count());
257   EXPECT_EQ(2, param->user_count());
258   EXPECT_EQ(0, live_while->user_count());
259   EXPECT_TRUE(HasInstruction(*computation, live_while));
260 
261   HloDCE dce;
262   EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
263 
264   EXPECT_EQ(3, computation->instruction_count());
265   EXPECT_EQ(2, param->user_count());
266   EXPECT_EQ(0, live_while->user_count());
267   EXPECT_TRUE(HasInstruction(*computation, live_while));
268 }
269 
270 // Tests that a nested call instruction with a side effect is not removed.
TEST_F(HloDceTest,CalledComputationWithNestedSideEffect)271 TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
272   auto module = CreateNewUnverifiedModule();
273   Shape shape = ShapeUtil::MakeShape(F32, {});
274 
275   // Nested called computation with a side effect.
276   auto nested_callee_builder =
277       HloComputation::Builder(TestName() + "-nested_callee");
278   {
279     auto param = nested_callee_builder.AddInstruction(
280         HloInstruction::CreateParameter(0, shape, "param"));
281     auto token =
282         nested_callee_builder.AddInstruction(HloInstruction::CreateToken());
283     nested_callee_builder.AddInstruction(
284         HloInstruction::CreateOutfeed(shape, param, token, ""));
285   }
286   auto nested_called_computation =
287       module->AddEmbeddedComputation(nested_callee_builder.Build());
288 
289   // Outer called computation that calls the nested computation.
290   auto callee_builder = HloComputation::Builder(TestName() + "-callee");
291   {
292     auto param = callee_builder.AddInstruction(
293         HloInstruction::CreateParameter(0, shape, "param"));
294     callee_builder.AddInstruction(
295         HloInstruction::CreateCall(shape, {param}, nested_called_computation));
296   }
297   auto called_computation =
298       module->AddEmbeddedComputation(callee_builder.Build());
299 
300   // Entry computation with a call instruction.
301   auto builder = HloComputation::Builder(TestName());
302   auto param = builder.AddInstruction(
303       HloInstruction::CreateParameter(0, shape, "param"));
304   auto live_call = builder.AddInstruction(
305       HloInstruction::CreateCall(shape, {param}, called_computation));
306   builder.AddInstruction(
307       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
308   auto computation = module->AddEntryComputation(builder.Build());
309 
310   EXPECT_EQ(3, computation->instruction_count());
311   EXPECT_EQ(2, param->user_count());
312   EXPECT_EQ(0, live_call->user_count());
313   EXPECT_TRUE(HasInstruction(*computation, live_call));
314 
315   HloDCE dce;
316   EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
317 
318   EXPECT_EQ(3, computation->instruction_count());
319   EXPECT_EQ(2, param->user_count());
320   EXPECT_EQ(0, live_call->user_count());
321   EXPECT_TRUE(HasInstruction(*computation, live_call));
322 }
323 
// Checks that when DCE removes the only user of an embedded computation (here
// the to_apply computation of a dead reduce), the computation itself is also
// removed from the module.
TEST_F(HloDceTest, RemoveDeadSubcomputation) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloComputation::Builder subcomp_builder("reduction_subcomp");
  {
    auto* param0 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0"));
    auto* param1 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1"));
    subcomp_builder.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1));
  }
  auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build());

  // Operands are added as separate statements so that their order in the
  // computation does not depend on unspecified argument-evaluation order.
  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0"));
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));

  // Create a dead reduce instruction. Reducing dimension 0 of an f32[100]
  // operand removes that dimension, so the result is the scalar f32[].
  builder.AddInstruction(HloInstruction::CreateReduce(
      ShapeUtil::MakeShape(F32, {}), input, init_value,
      /*dimensions_to_reduce=*/{0}, reduce_subcomp));

  // Add another instruction as the root of the computation.
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));

  module->AddEntryComputation(builder.Build());
  EXPECT_EQ(2, module->MakeComputationPostOrder().size());

  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // We should have DCE'ed the reduction computation along with the reduction
  // instruction.
  EXPECT_EQ(1, module->MakeComputationPostOrder().size());
}
364 
// Checks that an embedded computation is kept when DCE removes only one of its
// users (a dead reduce) while another user (the root reduce) remains.
TEST_F(HloDceTest, KeepUsedSubcomputation) {
  // The reduces below have correctly inferred shapes, so the module is
  // well-formed and can be verified.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloComputation::Builder subcomp_builder("reduction_subcomp");
  {
    auto* param0 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0"));
    auto* param1 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1"));
    subcomp_builder.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1));
  }
  auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build());

  // Reducing dimension 0 of an f32[100] operand yields the scalar f32[].
  const Shape reduced_shape = ShapeUtil::MakeShape(F32, {});

  // Create a dead reduce instruction. Operands are added as separate
  // statements so their order does not depend on argument-evaluation order.
  auto* dead_input = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0"));
  auto* dead_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(HloInstruction::CreateReduce(
      reduced_shape, dead_input, dead_init,
      /*dimensions_to_reduce=*/{0}, reduce_subcomp));

  // Add another instruction as the root of the computation that also uses
  // reduce_subcomp.
  auto* live_input = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1"));
  auto* live_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(HloInstruction::CreateReduce(
      reduced_shape, live_input, live_init,
      /*dimensions_to_reduce=*/{0}, reduce_subcomp));

  module->AddEntryComputation(builder.Build());
  EXPECT_EQ(2, module->MakeComputationPostOrder().size());

  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // We shouldn't have DCE'ed reduce_subcomp, even though we removed one of
  // its users.
  EXPECT_EQ(2, module->MakeComputationPostOrder().size());
}
411 
412 }  // namespace
413 }  // namespace xla
414