/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <memory>
#include <string>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace op = xla::testing::opcode_matchers;

using ::testing::_;

class HloRematerializationTest : public HloTestBase {
 protected:
  // Creates and returns a computation which can benefit from
  // rematerialization. The computation looks like:
  //
  //   F32[1] %param = {...}
  //   F32[] %reshape = reshape(%param)
  //   F32[1024] %bcast = broadcast(%reshape)
  //   F32[1024] %negate = negate(%bcast)
  //   F32[2048] %concat_1 = concat({%negate, %negate})
  //   F32[1] %slice_1 = slice(%concat_1, {0:1})
  //   F32[1025] %concat_2 = concat({%bcast, %slice_1})
  //   F32[1] %slice_2 = slice(%concat_2, {0:1})
  //
  // The instruction %bcast can be rematerialized before its use at %concat_2
  // to reduce peak memory usage. This avoids %bcast and %concat_1 being
  // simultaneously live. Peak memory use is about 16KB before
  // rematerialization (during execution of %concat_1) and about 12KB after
  // rematerializing %bcast for its use in %concat_2.
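  //
  // (Arithmetic: each F32[1024] buffer is 1024 * 4B = 4KB. During %concat_1,
  // %bcast (4KB), %negate (4KB), and %concat_1 (8KB) are live: ~16KB. With
  // %bcast rematerialized for %concat_2 it is no longer live during
  // %concat_1, leaving ~12KB.)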
  std::unique_ptr<HloComputation> MakeRematerializableComputation(
      const string& suffix = "") {
    auto builder = HloComputation::Builder(TestName() + suffix);
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    auto reshape = builder.AddInstruction(
        HloInstruction::CreateReshape(scalar_shape_, param));
    auto bcast = builder.AddInstruction(
        HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
    auto negate = builder.AddInstruction(
        HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast));
    auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {2048}), {negate, negate},
        /*dimension=*/0));
    auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
        vec1_shape_, concat_1, /*start_indices=*/{0},
        /*limit_indices=*/{1},
        /*strides=*/{1}));
    auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
        /*dimension=*/0));
    // Add a final slice to make the parameter shape match the output shape,
    // which is necessary to use this computation in a while.
    builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
                                                       /*start_indices=*/{0},
                                                       /*limit_indices=*/{1},
                                                       /*strides=*/{1}));
    return builder.Build();
  }

  // Creates and returns a computation which includes a while and can benefit
  // from rematerialization. The computation looks like:
  //
  //   F32[1] %param = {...}
  //   F32[] %reshape = reshape(%param)
  //   F32[1024] %bcast = broadcast(%reshape)
  //   F32[1] %slice_1 = slice(%bcast, {0:1})
  //   F32[1] %while = while(%slice_1, while_body, while_cond)
  //   F32[1025] %concat = concat({%bcast, %while})
  //   F32[1] %slice_2 = slice(%concat, {0:1})
  //
  // The instruction %bcast can be rematerialized before its use at %concat to
  // reduce peak memory usage. This avoids %bcast being live during execution
  // of the while. Peak memory use is the maximum of 8KB (at %concat) and 4KB
  // plus the memory used by the while subcomputations.
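  //
  // (Arithmetic: %bcast is 4KB and the F32[1025] %concat slightly more,
  // giving the ~8KB figure at %concat; across the while, this computation
  // holds only the 4KB %bcast on top of whatever the while body uses.)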
  std::unique_ptr<HloComputation> MakeRematerializableWhileComputation(
      HloComputation* while_cond, HloComputation* while_body,
      const string& suffix = "") {
    auto builder = HloComputation::Builder(TestName() + suffix);
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    auto reshape = builder.AddInstruction(
        HloInstruction::CreateReshape(scalar_shape_, param));
    auto bcast = builder.AddInstruction(
        HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
    auto slice_1 = builder.AddInstruction(
        HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
                                    /*limit_indices=*/{1},
                                    /*strides=*/{1}));
    auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
        vec1_shape_, while_cond, while_body, slice_1));
    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst},
        /*dimension=*/0));
    builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
                                                       /*start_indices=*/{0},
                                                       /*limit_indices=*/{1},
                                                       /*strides=*/{1}));
    return builder.Build();
  }

  // Creates and returns a trivial computation appropriate for use as a while
  // condition.
  std::unique_ptr<HloComputation> MakeConditionComputation() {
    auto builder = HloComputation::Builder(TestName() + ".cond");
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
    return builder.Build();
  }

  // Returns the byte size of the top-level buffer of the given shape.
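  // (The sizeof(void*) argument below is the pointer size that
  // ShapeUtil::ByteSizeOf uses when sizing the index tables of tuple shapes.)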
  static int64 ByteSizeOf(const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  }

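  // Verifies and schedules the module, then runs rematerialization with the
  // given memory limit. The module is scheduled first because
  // rematerialization operates on the module's schedule.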
  StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
                                         HloModule* module) {
    TF_EXPECT_OK(verifier().Run(module).status());
    HloMemoryScheduler scheduler(
        [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
        DefaultMemoryScheduler);
    TF_EXPECT_OK(scheduler.Run(module).status());
    HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
                               /*sizes=*/nullptr);
    return remat.Run(module);
  }

  // Various shapes used in the canned computations.
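  // F32 elements are 4 bytes, so vec1024_shape_ buffers occupy 4KB while the
  // scalar and vec1 shapes occupy 4 bytes apiece.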
  const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {});
  const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1});
  const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024});
};

// Test rematerialization of a single computation produced by
// MakeRematerializableComputation.
TEST_F(HloRematerializationTest, SingleComputation) {
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeRematerializableComputation());

  // Find and save the original broadcast instruction which should be
  // rematerialized.
  const HloInstruction* slice = computation->root_instruction();
  ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
  const HloInstruction* concat = slice->operand(0);
  const HloInstruction* bcast = concat->operand(0);

  // The computation requires 16KB without rematerialization but only 12KB
  // with it, so pick a memory limit between these values (14KB).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/14 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Root should not have changed.
  EXPECT_EQ(computation->root_instruction(), slice);

  // The broadcast should have been rematerialized.
  const HloInstruction* remat_bcast = concat->operand(0);
  EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));

  // The rematerialized broadcast should be immediately before the concat in
  // the sequence.
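  // (The root %slice_2 occupies the final slot of the schedule, so the concat
  // should sit at index instruction_count() - 2 and the rematerialized
  // broadcast at instruction_count() - 3.)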
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 2],
            concat);
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 3],
            remat_bcast);
}

// Test rematerialization of a single computation produced by
// MakeRematerializableComputation but with a sufficiently high memory limit
// such that no instructions are rematerialized.
TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeRematerializableComputation());

  EXPECT_EQ(computation->instruction_count(), 8);

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/20 * 1024, module.get()));

  // No instructions should have been rematerialized.
  EXPECT_FALSE(changed);
  EXPECT_EQ(computation->instruction_count(), 8);
}

// Test rematerialization of a computation which calls another computation via
// a while. Both the entry computation and the while body computation can have
// memory usage reduced via rematerialization; however, the memory limit is
// set such that only one computation needs an instruction rematerialized. The
// entry computation should be the one chosen because rematerialization inside
// the while will presumably be more expensive.
TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
  auto module = CreateNewVerifiedModule();

  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloComputation* while_cond =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloComputation* body_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".body"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/body_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(body_computation->instruction_count(), 8);

  // The body computation uses 16KB and the entry computation uses 2KB at the
  // while, so the peak memory use of the module is 18KB. Set the memory limit
  // a bit lower (17KB) to force rematerialization of the entry computation.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/17 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Only the entry computation should have a rematerialized instruction added.
  EXPECT_EQ(entry_computation->instruction_count(), 8);
  EXPECT_EQ(body_computation->instruction_count(), 8);
}

// Test rematerialization of a computation which calls another computation via
// a while. Both the entry computation and the while body computation should
// have instructions rematerialized.
TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
  auto module = CreateNewVerifiedModule();

  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloComputation* while_cond =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloComputation* body_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".body"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/body_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(body_computation->instruction_count(), 8);

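  // A 15KB limit is below the ~16KB peak of the body computation alone, so
  // rematerialization is needed in both the entry computation and the body.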
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/15 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Both computations should have rematerialized instructions added.
  EXPECT_EQ(entry_computation->instruction_count(), 9);
  EXPECT_EQ(body_computation->instruction_count(), 9);
}

// Test rematerialization of a doubly nested computation. All computations
// should have an instruction rematerialized.
TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
  auto module = CreateNewVerifiedModule();

  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloComputation* while_cond =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloComputation* inner_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".inner"));
  HloComputation* middle_computation =
      module->AddEmbeddedComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/inner_computation,
          /*suffix=*/".middle"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/middle_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(middle_computation->instruction_count(), 7);
  EXPECT_EQ(inner_computation->instruction_count(), 8);

  // If all computations are maximally rematerialized then peak memory usage
  // is ~12KB, so pick something slightly larger.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/13 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // All computations should have rematerialized instructions added.
  EXPECT_EQ(entry_computation->instruction_count(), 9);
  EXPECT_EQ(middle_computation->instruction_count(), 9);
  EXPECT_EQ(inner_computation->instruction_count(), 9);
}

TEST_F(HloRematerializationTest, RngNotRematerialized) {
  // Test that a single rng is not rematerialized:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] rng = rng(param)
  //   F32[1024] tanh = tanh(rng)
  //   F32[1024] exp = exp(rng)
  //   F32[1024] add_0 = add(rng, tanh)              // LIVE: add_0 + rng +
  //                                                 //       tanh + exp
  //
  //   F32[1024] add_1 = add(rng, add(exp, add_0))   // LIVE: add_1 + add_0 +
  //                                                 //       rng + tanh + exp
  //
  //   F32[1024] add_2 = add(rng, add(tanh, add_1))  // LIVE: add_2 + add_1 +
  //                                                 //       rng + tanh + exp
  auto module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto rng = builder.AddInstruction(HloInstruction::CreateRng(
      vec1024_shape_, RandomDistribution::RNG_UNIFORM, {param, param}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng));
  auto add_0 = builder.AddInstruction(
      HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, rng,
      builder.AddInstruction(HloInstruction::CreateBinary(
          vec1024_shape_, HloOpcode::kAdd, exp, add_0))));
  builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, rng,
      builder.AddInstruction(HloInstruction::CreateBinary(
          vec1024_shape_, HloOpcode::kAdd, tanh, add_1))));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  auto count_rngs = [](const HloComputation* computation) {
    int64 rng_count = 0;
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kRng) {
        ++rng_count;
      }
    }
    return rng_count;
  };
  // Before rematerialization there should be a single rng instruction in the
  // graph.
  ASSERT_EQ(count_rngs(entry_computation), 1);
  const int64 original_instruction_count =
      entry_computation->instruction_count();
  // Set the memory limit to 16KB (four F32[1024] buffers), below the ~20KB of
  // values simultaneously live at peak (see the LIVE annotations above), to
  // force rematerialization.
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      RunHloRematerialization(
          /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get()));
  EXPECT_TRUE(changed);
  // The rng should not have been rematerialized.
  EXPECT_EQ(count_rngs(entry_computation), 1);
  // There should have been rematerialization.
  EXPECT_GT(entry_computation->instruction_count(), original_instruction_count);
}

TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
  // Test that a single instruction is rematerialized several times. Module:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] %bcast = broadcast(%param)
  //   F32[1024] %add_1 = add(%bcast, bcast)
  //   F32[1024] %call_1 = call(Subcomputation, {%add_1})
  //   F32[1024] %add_2 = add(%bcast, call_1)
  //   F32[1024] %call_2 = call(Subcomputation, {%add_2})
  //   F32[1024] %add_3 = add(%bcast, call_2)
  //   F32[1024] %call_3 = call(Subcomputation, {%add_3})
  //   F32[1024] %add_4 = add(%bcast, call_3)
  //
  // Subcomputation:
  //   F32[1024] %param = {...}
  //   F32[2048] %concat = concat({%param, %param})
  //   F32[1024] %slice = slice(%concat)
  //
  // The value %bcast is live across each call to Subcomputation (which
  // requires 8KB) even though the value is not used in the calls.
  // Rematerializing %bcast across these calls reduces peak memory use from
  // ~20KB down to ~16KB.
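  //
  // (Arithmetic: at each call site, %bcast, the call operand, and the call
  // result are ~4KB apiece alongside the subcomputation's 8KB %concat, for
  // ~20KB in total; rematerializing %bcast saves its 4KB, giving ~16KB.)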
  auto module = CreateNewVerifiedModule();

  HloComputation* subcomputation = nullptr;
  {
    auto builder = HloComputation::Builder(TestName() + ".subcomputation");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
        /*dimension=*/0));
    builder.AddInstruction(HloInstruction::CreateSlice(
        vec1024_shape_, concat, /*start_indices=*/{0},
        /*limit_indices=*/{1024}, /*strides=*/{1}));
    subcomputation = module->AddEmbeddedComputation(builder.Build());
  }

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
  auto call_1 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
  auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
  auto call_2 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation));
  auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_2));
  auto call_3 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation));
  auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_3));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  auto count_broadcasts = [](const HloComputation* computation) {
    int64 bcast_count = 0;
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kBroadcast) {
        bcast_count++;
      }
    }
    return bcast_count;
  };

  // Before rematerialization there should be a single broadcast instruction in
  // the graph.
  EXPECT_EQ(count_broadcasts(entry_computation), 1);
  EXPECT_EQ(entry_computation->instruction_count(), 9);

  EXPECT_EQ(add_2->operand(0), bcast);
  EXPECT_EQ(add_3->operand(0), bcast);
  EXPECT_EQ(add_4->operand(0), bcast);

  // Pick a memory limit somewhere between 24KB (initial peak memory including
  // parameter and output) and 20KB (peak memory possible with
  // rematerialization).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/22 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // The broadcast should have been rematerialized 3 times.
  EXPECT_EQ(count_broadcasts(entry_computation), 4);
  EXPECT_EQ(entry_computation->instruction_count(), 12);

  // The operands of add_2, add_3, and add_4 should all be rematerialized
  // broadcasts.
  EXPECT_NE(add_2->operand(0), bcast);
  EXPECT_THAT(add_2->operand(0), op::Broadcast(param));
  EXPECT_NE(add_3->operand(0), bcast);
  EXPECT_THAT(add_3->operand(0), op::Broadcast(param));
  EXPECT_NE(add_4->operand(0), bcast);
  EXPECT_THAT(add_4->operand(0), op::Broadcast(param));
}

TEST_F(HloRematerializationTest, CopyNotRematerialized) {
  // Test that copies are not rematerialized.
  auto module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1024_shape_, "param"));

  auto copy = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kCopy, param));

  auto negate_a_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy));

  auto negate_a_2 = builder.AddInstruction(HloInstruction::CreateUnary(
      vec1024_shape_, HloOpcode::kNegate, negate_a_1));

  auto negate_b_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy));

  auto negate_b_2 = builder.AddInstruction(HloInstruction::CreateUnary(
      vec1024_shape_, HloOpcode::kNegate, negate_b_1));

  builder.AddInstruction(HloInstruction::CreateTuple({negate_a_2, negate_b_2}));

  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/1 * 1024, module.get()));

  auto count_copies = [](const HloComputation* computation) {
    int64 copy_count = 0;
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy) {
        copy_count++;
      }
    }
    return copy_count;
  };
  EXPECT_TRUE(changed);

  EXPECT_EQ(count_copies(entry_computation), 1);
}

class IndirectUseTest : public HloRematerializationTest,
                        public ::testing::WithParamInterface<bool> {};

TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
  // Test that a rematerializable instruction is not rematerialized if it has
  // an indirect use. The test is parameterized on whether the value has an
  // indirect use, and the instruction should be rematerialized iff the value
  // has no indirect use. Module:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] %bcast = broadcast(%param)
  //   F32[1024] %add_1 = add(%bcast, bcast)
  //   F32[1024] %call = call(Subcomputation, {%add_1})
  //   F32[1024] %add_2 = add(%bcast, call)
  //   {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2)
  //   F32[1024] %gte = GetTupleElement(%tuple, 0)
  //   F32[1024] %negate = negate(%gte)
  //
  // Subcomputation:
  //   F32[1024] %param = {...}
  //   F32[2048] %concat = concat({%param, %param})
  //   F32[1024] %slice = slice(%concat)
  //
  // The value %bcast is live across the call, and rematerializing %bcast
  // across that point would reduce peak memory use by 4KB. However, when
  // %bcast is used indirectly by %negate, rematerialization should not
  // happen.
  //
  // This test is parameterized on whether the broadcast has an indirect use
  // or not. The indirect use is controlled by the index of the
  // GetTupleElement instruction. If the element index is 0, then the %negate
  // operand aliases %bcast (i.e., %bcast is used indirectly by %negate);
  // otherwise the %negate operand aliases %add_2.
  const bool indirectly_used = GetParam();
  auto module = CreateNewVerifiedModule();

  HloComputation* subcomputation = nullptr;
  {
    auto builder = HloComputation::Builder(TestName() + ".subcomputation");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
        /*dimension=*/0));
    builder.AddInstruction(HloInstruction::CreateSlice(
        vec1024_shape_, concat, /*start_indices=*/{0},
        /*limit_indices=*/{1024}, /*strides=*/{1}));
    subcomputation = module->AddEmbeddedComputation(builder.Build());
  }

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
  auto call_1 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
  auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2}));
  auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      vec1024_shape_, tuple, indirectly_used ? 0 : 1));
  builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry_computation->instruction_count(), 8);

  // Pick a memory limit somewhere between 24KB (initial peak memory including
  // parameter and output) and 20KB (peak memory possible with
  // rematerialization).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/22 * 1024, module.get()));
  // Rematerialization should only occur if the rematerializable instruction
  // has no indirect uses.
  if (indirectly_used) {
    EXPECT_FALSE(changed);
    EXPECT_EQ(entry_computation->instruction_count(), 8);
  } else {
    EXPECT_TRUE(changed);
    EXPECT_EQ(entry_computation->instruction_count(), 9);
  }
}

INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest,
                         ::testing::Values(true, false));

}  // namespace

}  // namespace xla