/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <memory>
#include <string>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
31 
32 namespace xla {
33 namespace {
34 
35 namespace op = xla::testing::opcode_matchers;
36 
37 using ::testing::_;
38 
39 class HloRematerializationTest : public HloTestBase {
40  protected:
41   // Creates and returns a computation which can benefit from
42   // rematerialization. The computation looks like:
43   //
44   //   F32[1] %param = {...}
45   //   F32[] %reshape = reshape(F32[], param)
46   //   F32[1024] %bcast = broadcast(%param)
47   //   F32[1024] %negate = negate(%bcast)
48   //   F32[2048] %concat_1 = concat({%negate, %negate})
49   //   F32[1] %slice_1 = slice(%concat_1, {0:1})
50   //   F32[1025] %concat_2 = concat({%bcast, %slice_1})
51   //   F32[1] %slice_2 = slice(%concat_2, {0:1});
52   //
53   // The instruction %bcast can be rematerialized before its use at %concat_2
54   // to reduce peak memory usage. This avoids %bcast and %concat_1 being
55   // simultaneously live. Peak memory use is about 16KB before rematerialization
56   // (during execution of %concat_1) and about 12KB after rematerializing %bcast
57   // for its use in %concat_2.
MakeRematerializableComputation(const string & suffix="")58   std::unique_ptr<HloComputation> MakeRematerializableComputation(
59       const string& suffix = "") {
60     auto builder = HloComputation::Builder(TestName() + suffix);
61     auto param = builder.AddInstruction(
62         HloInstruction::CreateParameter(0, vec1_shape_, "param"));
63     auto reshape = builder.AddInstruction(
64         HloInstruction::CreateReshape(scalar_shape_, param));
65     auto bcast = builder.AddInstruction(
66         HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
67     auto negate = builder.AddInstruction(
68         HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast));
69     auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate(
70         ShapeUtil::MakeShape(xla::F32, {2048}), {negate, negate},
71         /*dimension=*/0));
72     auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
73         vec1_shape_, concat_1, /*start_indices=*/{0},
74         /*limit_indices=*/{1},
75         /*strides=*/{1}));
76     auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
77         ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
78         /*dimension=*/0));
79     // Add a final slice to make the parameter shape match the output shape
80     // which is necessary to use this computation in a while.
81     builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
82                                                        /*start_indices=*/{0},
83                                                        /*limit_indices=*/{1},
84                                                        /*strides=*/{1}));
85     return builder.Build();
86   }
87 
88   // Creates and returns a computation which includes a while and can benefit
89   // from rematerialization. The computation looks like:
90   //
91   //   F32[] %param = {...}
92   //   F32[1024] %bcast = broadcast(%param)
93   //   F32[1] %slice_1 = slice(%bcast, {0:1})
94   //   F32[1] %while = while(%slice_1, while_body, while_cond)
95   //   F32[1025] %concat = concat({%bcast, %while})
96   //   F32[1] %slice_2 = slice(%concat, {0:1});
97   //
98   // The instruction %bcast can be rematerialized before its use at %concat to
99   // reduce peak memory usage. This avoids %bcast being live during execution of
100   // the while. Peak memory use is maximum of 8K and 4K plus the memory use of
101   // the while subcomputations.
MakeRematerializableWhileComputation(HloComputation * while_cond,HloComputation * while_body,const string & suffix="")102   std::unique_ptr<HloComputation> MakeRematerializableWhileComputation(
103       HloComputation* while_cond, HloComputation* while_body,
104       const string& suffix = "") {
105     auto builder = HloComputation::Builder(TestName() + suffix);
106     auto param = builder.AddInstruction(
107         HloInstruction::CreateParameter(0, vec1_shape_, "param"));
108     auto reshape = builder.AddInstruction(
109         HloInstruction::CreateReshape(scalar_shape_, param));
110     auto bcast = builder.AddInstruction(
111         HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
112     auto slice_1 = builder.AddInstruction(
113         HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
114                                     /*limit_indices=*/{1},
115                                     /*strides=*/{1}));
116     auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
117         vec1_shape_, while_cond, while_body, slice_1));
118     auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
119         ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst},
120         /*dimension=*/0));
121     builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
122                                                        /*start_indices=*/{0},
123                                                        /*limit_indices=*/{1},
124                                                        /*strides=*/{1}));
125     return builder.Build();
126   }
127 
128   // Create and return a trivial computation appropriate for use as a while
129   // condition.
MakeConditionComputation()130   std::unique_ptr<HloComputation> MakeConditionComputation() {
131     auto builder = HloComputation::Builder(TestName() + ".cond");
132     builder.AddInstruction(
133         HloInstruction::CreateParameter(0, vec1_shape_, "param"));
134     builder.AddInstruction(
135         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
136     return builder.Build();
137   }
138 
139   // Return the byte size of the top-level buffer of the given shape.
ByteSizeOf(const Shape & shape)140   static int64 ByteSizeOf(const Shape& shape) {
141     return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
142   }
143 
RunHloRematerialization(int64 memory_limit_bytes,HloModule * module)144   StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
145                                          HloModule* module) {
146     TF_EXPECT_OK(verifier().Run(module).status());
147     HloMemoryScheduler scheduler(
148         [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
149         DefaultMemoryScheduler);
150     TF_EXPECT_OK(scheduler.Run(module).status());
151     HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
152                                /*sizes=*/nullptr);
153     return remat.Run(module);
154   }
155 
156   // Various shapes used in the canned computations.
157   const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {});
158   const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1});
159   const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024});
160 };
161 
162 // Test rematerialization of a single computation produced by
163 // MakeRematerializableComputation.
TEST_F(HloRematerializationTest,SingleComputation)164 TEST_F(HloRematerializationTest, SingleComputation) {
165   auto module = CreateNewVerifiedModule();
166   HloComputation* computation =
167       module->AddEntryComputation(MakeRematerializableComputation());
168 
169   // Find and save the original broadcast instruction which should be
170   // rematerialized.
171   const HloInstruction* slice = computation->root_instruction();
172   ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
173   const HloInstruction* concat = slice->operand(0);
174   const HloInstruction* bcast = concat->operand(0);
175 
176   // Computation requires 16KB without rematerialization, but uses only 12KB
177   // with rematerialization so pick a memory limit between these values (14KB).
178   TF_ASSERT_OK_AND_ASSIGN(bool changed,
179                           RunHloRematerialization(
180                               /*memory_limit_bytes=*/14 * 1024, module.get()));
181   EXPECT_TRUE(changed);
182 
183   // Root should not have changed.
184   EXPECT_EQ(computation->root_instruction(), slice);
185 
186   // The broadcast should have been rematerialized.
187   const HloInstruction* remat_bcast = concat->operand(0);
188   EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));
189 
190   // The rematerialized broadcast should be immediate before the concat in the
191   // sequence.
192   EXPECT_EQ(module->schedule()
193                 .sequence(computation)
194                 .instructions()[computation->instruction_count() - 2],
195             concat);
196   EXPECT_EQ(module->schedule()
197                 .sequence(computation)
198                 .instructions()[computation->instruction_count() - 3],
199             remat_bcast);
200 }
201 
202 // Test rematerialization of a single computation produced by
203 // MakeRematerializableComputation but with a sufficiently high memory limit
204 // such that no instructions are rematerialized.
TEST_F(HloRematerializationTest,SingleComputationNoRematerialization)205 TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
206   auto module = CreateNewVerifiedModule();
207   HloComputation* computation =
208       module->AddEntryComputation(MakeRematerializableComputation());
209 
210   EXPECT_EQ(computation->instruction_count(), 8);
211 
212   TF_ASSERT_OK_AND_ASSIGN(bool changed,
213                           RunHloRematerialization(
214                               /*memory_limit_bytes=*/20 * 1024, module.get()));
215 
216   // No instructions should have been materialized.
217   EXPECT_FALSE(changed);
218   EXPECT_EQ(computation->instruction_count(), 8);
219 }
220 
221 // Test rematerialization of a computation which calls another computation via a
222 // while. Both the entry computation and while body computation can have memory
223 // usage reduced via rematerialization however the memory limit is set such that
224 // only one computation needs to have an instruction rematerialized. The entry
225 // computation should be the one chosen because rematerialization in the while
226 // will presumably be more expensive.
TEST_F(HloRematerializationTest,RematerializeAroundWhile)227 TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
228   auto module = CreateNewVerifiedModule();
229 
230   auto cond_builder = HloComputation::Builder(TestName() + ".cond");
231   cond_builder.AddInstruction(
232       HloInstruction::CreateParameter(0, vec1_shape_, "param"));
233   cond_builder.AddInstruction(
234       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
235   HloComputation* while_cond =
236       module->AddEmbeddedComputation(cond_builder.Build());
237 
238   HloComputation* body_computation = module->AddEmbeddedComputation(
239       MakeRematerializableComputation(/*suffix=*/".body"));
240   HloComputation* entry_computation =
241       module->AddEntryComputation(MakeRematerializableWhileComputation(
242           while_cond, /*while_body=*/body_computation));
243 
244   EXPECT_EQ(entry_computation->instruction_count(), 7);
245   EXPECT_EQ(body_computation->instruction_count(), 8);
246 
247   // The body computation uses 16KB and the entry computation uses 2KB at the
248   // while so the peak memory use of the module is 18KB. Set the memory limit a
249   // bit lower (17KB) to force rematerialization of the entry computation.
250   TF_ASSERT_OK_AND_ASSIGN(bool changed,
251                           RunHloRematerialization(
252                               /*memory_limit_bytes=*/17 * 1024, module.get()));
253   EXPECT_TRUE(changed);
254 
255   // Only the entry computation should have a rematerialized instruction added.
256   EXPECT_EQ(entry_computation->instruction_count(), 8);
257   EXPECT_EQ(body_computation->instruction_count(), 8);
258 }
259 
260 // Test rematerialization of a computation which calls another computation via a
261 // while. Both the entry computation and while body computation should have
262 // computations rematerialized.
TEST_F(HloRematerializationTest,RematerializeEntryAndWhileBody)263 TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
264   auto module = CreateNewVerifiedModule();
265 
266   auto cond_builder = HloComputation::Builder(TestName() + ".cond");
267   cond_builder.AddInstruction(
268       HloInstruction::CreateParameter(0, vec1_shape_, "param"));
269   cond_builder.AddInstruction(
270       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
271   HloComputation* while_cond =
272       module->AddEmbeddedComputation(cond_builder.Build());
273 
274   HloComputation* body_computation = module->AddEmbeddedComputation(
275       MakeRematerializableComputation(/*suffix=*/".body"));
276   HloComputation* entry_computation =
277       module->AddEntryComputation(MakeRematerializableWhileComputation(
278           while_cond, /*while_body=*/body_computation));
279 
280   EXPECT_EQ(entry_computation->instruction_count(), 7);
281   EXPECT_EQ(body_computation->instruction_count(), 8);
282 
283   TF_ASSERT_OK_AND_ASSIGN(bool changed,
284                           RunHloRematerialization(
285                               /*memory_limit_bytes=*/15 * 1024, module.get()));
286   EXPECT_TRUE(changed);
287 
288   // Both computations should have rematerialized instructions added.
289   EXPECT_EQ(entry_computation->instruction_count(), 9);
290   EXPECT_EQ(body_computation->instruction_count(), 9);
291 }
292 
293 // Test rematerialization of a doubly nested computation. All computations
294 // should have an instruction rematerialized.
TEST_F(HloRematerializationTest,RematerializeNestedComputations)295 TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
296   auto module = CreateNewVerifiedModule();
297 
298   auto cond_builder = HloComputation::Builder(TestName() + ".cond");
299   cond_builder.AddInstruction(
300       HloInstruction::CreateParameter(0, vec1_shape_, "param"));
301   cond_builder.AddInstruction(
302       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
303   HloComputation* while_cond =
304       module->AddEmbeddedComputation(cond_builder.Build());
305 
306   HloComputation* inner_computation = module->AddEmbeddedComputation(
307       MakeRematerializableComputation(/*suffix=*/".inner"));
308   HloComputation* middle_computation =
309       module->AddEmbeddedComputation(MakeRematerializableWhileComputation(
310           while_cond, /*while_body=*/inner_computation,
311           /*suffix=*/".middle"));
312   HloComputation* entry_computation =
313       module->AddEntryComputation(MakeRematerializableWhileComputation(
314           while_cond, /*while_body=*/middle_computation));
315 
316   EXPECT_EQ(entry_computation->instruction_count(), 7);
317   EXPECT_EQ(middle_computation->instruction_count(), 7);
318   EXPECT_EQ(inner_computation->instruction_count(), 8);
319 
320   // If all computations are maximally rematerialized then peak memory usage is
321   // ~12K so pick something slightly larger.
322   TF_ASSERT_OK_AND_ASSIGN(bool changed,
323                           RunHloRematerialization(
324                               /*memory_limit_bytes=*/13 * 1024, module.get()));
325   EXPECT_TRUE(changed);
326 
327   // All computations should have rematerialized instructions added.
328   EXPECT_EQ(entry_computation->instruction_count(), 9);
329   EXPECT_EQ(middle_computation->instruction_count(), 9);
330   EXPECT_EQ(inner_computation->instruction_count(), 9);
331 }
332 
TEST_F(HloRematerializationTest,RngNotRematerialized)333 TEST_F(HloRematerializationTest, RngNotRematerialized) {
334   // Test that a single rng is not rematerialized:
335   //
336   // Entry computation:
337   //   F32[] %param = {...}
338   //   F32[1024] rng = rng(param)
339   //   F32[1024] tanh = tanh(rng)
340   //   F32[1024] exp = exp(rng)
341   //   F32[1024] add_0 = add(rng, tanh)              // LIVE: add_0 + rng +
342   //                                                 //       tanh + exp
343   //
344   //   F32[1024] add_1 = add(rng, add(exp, add_0))   // LIVE: add_1 + add_0 +
345   //                                                 //       rng + tanh + exp
346   //
347   //   F32[1024] add_2 = add(rng, add(tanh, add_1))  // LIVE: add_2 + add_1 +
348   //                                                 //       rng + tanh + exp
349   auto module = CreateNewVerifiedModule();
350 
351   auto builder = HloComputation::Builder(TestName());
352   auto param = builder.AddInstruction(
353       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
354   auto rng = builder.AddInstruction(HloInstruction::CreateRng(
355       vec1024_shape_, RandomDistribution::RNG_UNIFORM, {param, param}));
356   auto tanh = builder.AddInstruction(
357       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng));
358   auto exp = builder.AddInstruction(
359       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng));
360   auto add_0 = builder.AddInstruction(
361       HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh));
362   auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
363       vec1024_shape_, HloOpcode::kAdd, rng,
364       builder.AddInstruction(HloInstruction::CreateBinary(
365           vec1024_shape_, HloOpcode::kAdd, exp, add_0))));
366   builder.AddInstruction(HloInstruction::CreateBinary(
367       vec1024_shape_, HloOpcode::kAdd, rng,
368       builder.AddInstruction(HloInstruction::CreateBinary(
369           vec1024_shape_, HloOpcode::kAdd, tanh, add_1))));
370   HloComputation* entry_computation =
371       module->AddEntryComputation(builder.Build());
372 
373   auto count_rngs = [](const HloComputation* computation) {
374     int64 rng_count = 0;
375     for (auto* instruction : computation->instructions()) {
376       if (instruction->opcode() == HloOpcode::kRng) {
377         ++rng_count;
378       }
379     }
380     return rng_count;
381   };
382   // Before rematerialization there should be a single broadcast rng in
383   // the graph.
384   ASSERT_EQ(count_rngs(entry_computation), 1);
385   const int64 original_instruction_count =
386       entry_computation->instruction_count();
387   // Pick a memory limit some where between 24KB (initial peak memory including
388   // parameter and output) and 20KB (peak memory possible with
389   // rematerialization).
390   TF_ASSERT_OK_AND_ASSIGN(
391       bool changed,
392       RunHloRematerialization(
393           /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get()));
394   EXPECT_TRUE(changed);
395   // The rng should not have been rematerialized.
396   EXPECT_EQ(count_rngs(entry_computation), 1);
397   // There should have been rematerialization.
398   EXPECT_GT(entry_computation->instruction_count(), original_instruction_count);
399 }
400 
TEST_F(HloRematerializationTest,InstructionRematerializedMultipleTimes)401 TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
402   // Test that a single instruction is rematerialized several times. Module:
403   //
404   // Entry computation:
405   //   F32[] %param = {...}
406   //   F32[1024] %bcast = broadcast(%param)
407   //   F32[1024] %add_1 = add(%bcast, bcast)
408   //   F32[1024] %call_1 = call(Subcomputation, {%add_1})
409   //   F32[1024] %add_2 = add(%bcast, call_1)
410   //   F32[1024] %call_2 = call(SubComputation, {%add_2})
411   //   F32[1024] %add_3 = add(%bcast, call_2)
412   //   F32[1024] %call_3 = call(Subcomputation, {%add_3})
413   //   F32[1024] %add_4 = add(%bcast, call_3)
414   //
415   // Subcomputation:
416   //   F32[1024] %param = {...}
417   //   F32[2048] %concat = concat({%param, %param})
418   //   F32[1024] %slice = slice(%concat)
419   //
420   // The value %bcast is live across each call of Subcomputation (which requires
421   // 8KB) though the value is not used in the calls. Rematerializing %bcast
422   // across these calls reduces peak memory use from ~20KB down to ~16KB.
423   auto module = CreateNewVerifiedModule();
424 
425   HloComputation* subcomputation = nullptr;
426   {
427     auto builder = HloComputation::Builder(TestName() + ".subcomputation");
428     auto param = builder.AddInstruction(
429         HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
430     auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
431         ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
432         /*dimension=*/0));
433     builder.AddInstruction(HloInstruction::CreateSlice(
434         vec1024_shape_, concat, /*start_indices=*/{0},
435         /*limit_indices=*/{1024}, /*strides=*/{1}));
436     subcomputation = module->AddEmbeddedComputation(builder.Build());
437   }
438 
439   auto builder = HloComputation::Builder(TestName());
440   auto param = builder.AddInstruction(
441       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
442   auto bcast = builder.AddInstruction(
443       HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
444   auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
445       vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
446   auto call_1 = builder.AddInstruction(
447       HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
448   auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
449       vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
450   auto call_2 = builder.AddInstruction(
451       HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation));
452   auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary(
453       vec1024_shape_, HloOpcode::kAdd, bcast, call_2));
454   auto call_3 = builder.AddInstruction(
455       HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation));
456   auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary(
457       vec1024_shape_, HloOpcode::kAdd, bcast, call_3));
458   HloComputation* entry_computation =
459       module->AddEntryComputation(builder.Build());
460 
461   auto count_broadcasts = [](const HloComputation* computation) {
462     int64 bcast_count = 0;
463     for (auto* instruction : computation->instructions()) {
464       if (instruction->opcode() == HloOpcode::kBroadcast) {
465         bcast_count++;
466       }
467     }
468     return bcast_count;
469   };
470 
471   // Before rematerialization there should be a single broadcast instruction in
472   // the graph.
473   EXPECT_EQ(count_broadcasts(entry_computation), 1);
474   EXPECT_EQ(entry_computation->instruction_count(), 9);
475 
476   EXPECT_EQ(add_2->operand(0), bcast);
477   EXPECT_EQ(add_3->operand(0), bcast);
478   EXPECT_EQ(add_4->operand(0), bcast);
479 
480   // Pick a memory limit some where between 24KB (initial peak memory including
481   // parameter and output) and 20KB (peak memory possible with
482   // rematerialization).
483   TF_ASSERT_OK_AND_ASSIGN(bool changed,
484                           RunHloRematerialization(
485                               /*memory_limit_bytes=*/22 * 1024, module.get()));
486   EXPECT_TRUE(changed);
487 
488   // The broadcast should have been rematerialized 3 times.
489   EXPECT_EQ(count_broadcasts(entry_computation), 4);
490   EXPECT_EQ(entry_computation->instruction_count(), 12);
491 
492   // The operands of add_2, add_3, and add_4 should all be rematerialized
493   // broadcasts.
494   EXPECT_NE(add_2->operand(0), bcast);
495   EXPECT_THAT(add_2->operand(0), op::Broadcast(param));
496   EXPECT_NE(add_3->operand(0), bcast);
497   EXPECT_THAT(add_3->operand(0), op::Broadcast(param));
498   EXPECT_NE(add_4->operand(0), bcast);
499   EXPECT_THAT(add_4->operand(0), op::Broadcast(param));
500 }
501 
TEST_F(HloRematerializationTest,CopyNotRematerialized)502 TEST_F(HloRematerializationTest, CopyNotRematerialized) {
503   // Test that copies are not rematerialized.
504   auto module = CreateNewVerifiedModule();
505 
506   auto builder = HloComputation::Builder(TestName());
507   auto param = builder.AddInstruction(
508       HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
509 
510   auto copy = builder.AddInstruction(
511       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kCopy, param));
512 
513   auto negate_a_1 = builder.AddInstruction(
514       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy));
515 
516   auto negate_a_2 = builder.AddInstruction(HloInstruction::CreateUnary(
517       vec1024_shape_, HloOpcode::kNegate, negate_a_1));
518 
519   auto negate_b_1 = builder.AddInstruction(
520       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy));
521 
522   auto negate_b_2 = builder.AddInstruction(HloInstruction::CreateUnary(
523       vec1024_shape_, HloOpcode::kNegate, negate_b_1));
524 
525   builder.AddInstruction(HloInstruction::CreateTuple({negate_a_2, negate_b_2}));
526 
527   HloComputation* entry_computation =
528       module->AddEntryComputation(builder.Build());
529 
530   TF_ASSERT_OK_AND_ASSIGN(bool changed,
531                           RunHloRematerialization(
532                               /*memory_limit_bytes=*/1 * 1024, module.get()));
533 
534   auto count_copies = [](const HloComputation* computation) {
535     int64 copy_count = 0;
536     for (auto* instruction : computation->instructions()) {
537       if (instruction->opcode() == HloOpcode::kCopy) {
538         copy_count++;
539       }
540     }
541     return copy_count;
542   };
543   EXPECT_TRUE(changed);
544 
545   EXPECT_EQ(count_copies(entry_computation), 1);
546 }
547 
548 class IndirectUseTest : public HloRematerializationTest,
549                         public ::testing::WithParamInterface<bool> {};
550 
TEST_P(IndirectUseTest,IndirectUseNotRematerialized)551 TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
552   // Test that an rematerializable instruction is not rematerialized if it has
553   // an indirect use. Test is parameterized on whether the value has an indirect
554   // use, and the instruction should be rematerialized iff the value has no
555   // indirect use. Module:
556   //
557   // Entry computation:
558   //   F32[] %param = {...}
559   //   F32[1024] %bcast = broadcast(%param)
560   //   F32[1024] %add_1 = add(%bcast, bcast)
561   //   F32[1024] %call = call(Subcomputation, {%add_1})
562   //   F32[1024] %add_2 = add(%bcast, call)
563   //   {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2)
564   //   F32[1024] %gte = GetTupleElememt(%tuple, 0)
565   //   F32[1024] %negate = negate(%gte)
566   //
567   // Subcomputation:
568   //   F32[1024] %param = {...}
569   //   F32[2048] %concat = concat({%param, %param})
570   //   F32[1024] %slice = slice(%concat)
571   //
572   // The value %bcast is live across the call and rematerialization of %bcast
573   // across that point would reduce peak memory use by 4KB. However, %bcast is
574   // used indirectly in the %negate so rematerialization should not happen.
575   //
576   // This test is parameterized on whether the broadcast has an indirect use or
577   // not. The indirect use is controlled by the index of the GetTupleElement
578   // instruction. If the element is 0, then the %negate operand aliases %bcast
579   // (ie %bcast is used indirectly by %negate), otherwise the %negate operand
580   // aliases %add_2.
581   const bool indirectly_used = GetParam();
582   auto module = CreateNewVerifiedModule();
583 
584   HloComputation* subcomputation = nullptr;
585   {
586     auto builder = HloComputation::Builder(TestName() + ".subcomputation");
587     auto param = builder.AddInstruction(
588         HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
589     auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
590         ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
591         /*dimension=*/0));
592     builder.AddInstruction(HloInstruction::CreateSlice(
593         vec1024_shape_, concat, /*start_indices=*/{0},
594         /*limit_indices=*/{1024}, /*strides=*/{1}));
595     subcomputation = module->AddEmbeddedComputation(builder.Build());
596   }
597 
598   auto builder = HloComputation::Builder(TestName());
599   auto param = builder.AddInstruction(
600       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
601   auto bcast = builder.AddInstruction(
602       HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
603   auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
604       vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
605   auto call_1 = builder.AddInstruction(
606       HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
607   auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
608       vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
609   auto tuple =
610       builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2}));
611   auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
612       vec1024_shape_, tuple, indirectly_used ? 0 : 1));
613   builder.AddInstruction(
614       HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte));
615   HloComputation* entry_computation =
616       module->AddEntryComputation(builder.Build());
617 
618   EXPECT_EQ(entry_computation->instruction_count(), 8);
619 
620   // Pick a memory limit some where between 24KB (initial peak memory including
621   // parameter and output) and 20KB (peak memory possible with
622   // rematerialization).
623   TF_ASSERT_OK_AND_ASSIGN(bool changed,
624                           RunHloRematerialization(
625                               /*memory_limit_bytes=*/22 * 1024, module.get()));
626   // Rematerialization should only occur if the rematerializable instruction has
627   // no indirect uses.
628   if (indirectly_used) {
629     EXPECT_FALSE(changed);
630     EXPECT_EQ(entry_computation->instruction_count(), 8);
631   } else {
632     EXPECT_TRUE(changed);
633     EXPECT_EQ(entry_computation->instruction_count(), 9);
634   }
635 }
636 
637 INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest,
638                          ::testing::Values(true, false));
639 
640 }  // namespace
641 
642 }  // namespace xla
643