/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <memory>
#include <string>

21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31
namespace xla {
namespace {

namespace op = xla::testing::opcode_matchers;

using ::testing::_;

class HloRematerializationTest : public HloTestBase {
 protected:
  // Creates and returns a computation which can benefit from
  // rematerialization. The computation looks like:
  //
  //   F32[1] %param = {...}
  //   F32[] %reshape = reshape(F32[], param)
  //   F32[1024] %bcast = broadcast(%param)
  //   F32[1024] %negate = negate(%bcast)
  //   F32[2048] %concat_1 = concat({%negate, %negate})
  //   F32[1] %slice_1 = slice(%concat_1, {0:1})
  //   F32[1025] %concat_2 = concat({%bcast, %slice_1})
  //   F32[1] %slice_2 = slice(%concat_2, {0:1});
  //
  // The instruction %bcast can be rematerialized before its use at %concat_2
  // to reduce peak memory usage. This avoids %bcast and %concat_1 being
  // simultaneously live. Peak memory use is about 16KB before
  // rematerialization (during execution of %concat_1) and about 12KB after
  // rematerializing %bcast for its use in %concat_2.
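  // (During %concat_1, the buffers for %bcast (4KB), %negate (4KB), and
  // %concat_1 (8KB) are all simultaneously live, giving the ~16KB figure;
  // rematerializing %bcast removes it from that working set.)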
  std::unique_ptr<HloComputation> MakeRematerializableComputation(
      const string& suffix = "") {
    auto builder = HloComputation::Builder(TestName() + suffix);
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    auto reshape = builder.AddInstruction(
        HloInstruction::CreateReshape(scalar_shape_, param));
    auto bcast = builder.AddInstruction(
        HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
    auto negate = builder.AddInstruction(
        HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast));
    auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {2048}), {negate, negate},
        /*dimension=*/0));
    auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
        vec1_shape_, concat_1, /*start_indices=*/{0},
        /*limit_indices=*/{1},
        /*strides=*/{1}));
    auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
        /*dimension=*/0));
    // Add a final slice to make the parameter shape match the output shape
    // which is necessary to use this computation in a while.
    builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
                                                       /*start_indices=*/{0},
                                                       /*limit_indices=*/{1},
                                                       /*strides=*/{1}));
    return builder.Build();
  }

  // Creates and returns a computation which includes a while and can benefit
  // from rematerialization. The computation looks like:
  //
  //   F32[] %param = {...}
  //   F32[1024] %bcast = broadcast(%param)
  //   F32[1] %slice_1 = slice(%bcast, {0:1})
  //   F32[1] %while = while(%slice_1, while_body, while_cond)
  //   F32[1025] %concat = concat({%bcast, %while})
  //   F32[1] %slice_2 = slice(%concat, {0:1});
  //
  // The instruction %bcast can be rematerialized before its use at %concat to
  // reduce peak memory usage. This avoids %bcast being live during execution
  // of the while. Peak memory use is the maximum of 8KB and (4KB plus the
  // memory use of the while subcomputations).
  std::unique_ptr<HloComputation> MakeRematerializableWhileComputation(
      HloComputation* while_cond, HloComputation* while_body,
      const string& suffix = "") {
    auto builder = HloComputation::Builder(TestName() + suffix);
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    auto reshape = builder.AddInstruction(
        HloInstruction::CreateReshape(scalar_shape_, param));
    auto bcast = builder.AddInstruction(
        HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
    auto slice_1 = builder.AddInstruction(
        HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
                                    /*limit_indices=*/{1},
                                    /*strides=*/{1}));
    auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
        vec1_shape_, while_cond, while_body, slice_1));
    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst},
        /*dimension=*/0));
    builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
                                                       /*start_indices=*/{0},
                                                       /*limit_indices=*/{1},
                                                       /*strides=*/{1}));
    return builder.Build();
  }

  // Create and return a trivial computation appropriate for use as a while
  // condition.
  std::unique_ptr<HloComputation> MakeConditionComputation() {
    auto builder = HloComputation::Builder(TestName() + ".cond");
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
    return builder.Build();
  }

  // Return the byte size of the top-level buffer of the given shape.
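  // The second argument (sizeof(void*)) is the pointer size ShapeUtil uses
  // when sizing tuple-shaped buffers.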
  static int64 ByteSizeOf(const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  }

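  // Verifies and schedules 'module', then runs the rematerialization pass with
  // the given memory limit. Returns whether the pass changed the module.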
  StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
                                         HloModule* module) {
    TF_EXPECT_OK(verifier().Run(module).status());
    HloMemoryScheduler scheduler(
        [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
        DefaultMemoryScheduler);
    TF_EXPECT_OK(scheduler.Run(module).status());
    HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
                               /*sizes=*/nullptr);
    return remat.Run(module);
  }

  // Various shapes used in the canned computations.
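  // An F32[1024] buffer is 1024 * 4 bytes = 4KB, which is the unit behind the
  // KB figures quoted in the comments throughout this file.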
  const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {});
  const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1});
  const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024});
};

// Test rematerialization of a single computation produced by
// MakeRematerializableComputation.
TEST_F(HloRematerializationTest, SingleComputation) {
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeRematerializableComputation());

  // Find and save the original broadcast instruction which should be
  // rematerialized.
  const HloInstruction* slice = computation->root_instruction();
  ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
  const HloInstruction* concat = slice->operand(0);
  const HloInstruction* bcast = concat->operand(0);

  // The computation requires 16KB without rematerialization, but only 12KB
  // with rematerialization, so pick a memory limit between these values
  // (14KB).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/14 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Root should not have changed.
  EXPECT_EQ(computation->root_instruction(), slice);

  // The broadcast should have been rematerialized.
  const HloInstruction* remat_bcast = concat->operand(0);
  EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));

  // The rematerialized broadcast should be immediately before the concat in
  // the sequence.
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 2],
            concat);
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 3],
            remat_bcast);
}

// Test rematerialization of a single computation produced by
// MakeRematerializableComputation but with a sufficiently high memory limit
// such that no instructions are rematerialized.
TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeRematerializableComputation());

  EXPECT_EQ(computation->instruction_count(), 8);

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/20 * 1024, module.get()));

  // No instructions should have been rematerialized.
  EXPECT_FALSE(changed);
  EXPECT_EQ(computation->instruction_count(), 8);
}

// Test rematerialization of a computation which calls another computation via
// a while. Both the entry computation and while body computation can have
// memory usage reduced via rematerialization; however, the memory limit is set
// such that only one computation needs to have an instruction rematerialized.
// The entry computation should be the one chosen because rematerialization in
// the while will presumably be more expensive.
TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
  auto module = CreateNewVerifiedModule();

  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloComputation* while_cond =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloComputation* body_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".body"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/body_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(body_computation->instruction_count(), 8);

  // The body computation uses 16KB and the entry computation uses 2KB at the
  // while so the peak memory use of the module is 18KB. Set the memory limit a
  // bit lower (17KB) to force rematerialization of the entry computation.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/17 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Only the entry computation should have a rematerialized instruction added.
  EXPECT_EQ(entry_computation->instruction_count(), 8);
  EXPECT_EQ(body_computation->instruction_count(), 8);
}

// Test rematerialization of a computation which calls another computation via
// a while. Both the entry computation and while body computation should have
// instructions rematerialized.
TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
  auto module = CreateNewVerifiedModule();

  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloComputation* while_cond =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloComputation* body_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".body"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/body_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(body_computation->instruction_count(), 8);

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/15 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Both computations should have rematerialized instructions added.
  EXPECT_EQ(entry_computation->instruction_count(), 9);
  EXPECT_EQ(body_computation->instruction_count(), 9);
}

// Test rematerialization of a doubly nested computation. All computations
// should have an instruction rematerialized.
TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
  auto module = CreateNewVerifiedModule();

  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloComputation* while_cond =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloComputation* inner_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".inner"));
  HloComputation* middle_computation =
      module->AddEmbeddedComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/inner_computation,
          /*suffix=*/".middle"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/middle_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(middle_computation->instruction_count(), 7);
  EXPECT_EQ(inner_computation->instruction_count(), 8);

  // If all computations are maximally rematerialized then peak memory usage
  // is ~12KB, so pick something slightly larger.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/13 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // All computations should have rematerialized instructions added.
  EXPECT_EQ(entry_computation->instruction_count(), 9);
  EXPECT_EQ(middle_computation->instruction_count(), 9);
  EXPECT_EQ(inner_computation->instruction_count(), 9);
}

TEST_F(HloRematerializationTest, RngNotRematerialized) {
  // Test that a single rng is not rematerialized:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] rng = rng(param)
  //   F32[1024] tanh = tanh(rng)
  //   F32[1024] exp = exp(rng)
  //   F32[1024] add_0 = add(rng, tanh)              // LIVE: add_0 + rng +
  //                                                 //       tanh + exp
  //
  //   F32[1024] add_1 = add(rng, add(exp, add_0))   // LIVE: add_1 + add_0 +
  //                                                 //       rng + tanh + exp
  //
  //   F32[1024] add_2 = add(rng, add(tanh, add_1))  // LIVE: add_2 + add_1 +
  //                                                 //       rng + tanh + exp
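  //
  // Rematerializing the rng would duplicate the random-number generation and
  // could change the values its users observe, so the pass is expected to
  // leave it alone even under memory pressure.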
  auto module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto rng = builder.AddInstruction(HloInstruction::CreateRng(
      vec1024_shape_, RandomDistribution::RNG_UNIFORM, {param, param}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng));
  auto add_0 = builder.AddInstruction(
      HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, rng,
      builder.AddInstruction(HloInstruction::CreateBinary(
          vec1024_shape_, HloOpcode::kAdd, exp, add_0))));
  builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, rng,
      builder.AddInstruction(HloInstruction::CreateBinary(
          vec1024_shape_, HloOpcode::kAdd, tanh, add_1))));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  auto count_rngs = [](const HloComputation* computation) {
    int64 rng_count = 0;
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kRng) {
        ++rng_count;
      }
    }
    return rng_count;
  };
  // Before rematerialization there should be a single rng instruction in
  // the graph.
  ASSERT_EQ(count_rngs(entry_computation), 1);
  const int64 original_instruction_count =
      entry_computation->instruction_count();
  // Pick a memory limit somewhere between 24KB (initial peak memory including
  // parameter and output) and 20KB (peak memory possible with
  // rematerialization).
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      RunHloRematerialization(
          /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get()));
  EXPECT_TRUE(changed);
  // The rng should not have been rematerialized.
  EXPECT_EQ(count_rngs(entry_computation), 1);
  // There should have been rematerialization.
  EXPECT_GT(entry_computation->instruction_count(),
            original_instruction_count);
}

TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
  // Test that a single instruction is rematerialized several times. Module:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] %bcast = broadcast(%param)
  //   F32[1024] %add_1 = add(%bcast, bcast)
  //   F32[1024] %call_1 = call(Subcomputation, {%add_1})
  //   F32[1024] %add_2 = add(%bcast, call_1)
  //   F32[1024] %call_2 = call(Subcomputation, {%add_2})
  //   F32[1024] %add_3 = add(%bcast, call_2)
  //   F32[1024] %call_3 = call(Subcomputation, {%add_3})
  //   F32[1024] %add_4 = add(%bcast, call_3)
  //
  // Subcomputation:
  //   F32[1024] %param = {...}
  //   F32[2048] %concat = concat({%param, %param})
  //   F32[1024] %slice = slice(%concat)
  //
  // The value %bcast is live across each call of Subcomputation (which
  // requires 8KB) though the value is not used in the calls. Rematerializing
  // %bcast across these calls reduces peak memory use from ~20KB down to
  // ~16KB.
  auto module = CreateNewVerifiedModule();

  HloComputation* subcomputation = nullptr;
  {
    auto builder = HloComputation::Builder(TestName() + ".subcomputation");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
        /*dimension=*/0));
    builder.AddInstruction(HloInstruction::CreateSlice(
        vec1024_shape_, concat, /*start_indices=*/{0},
        /*limit_indices=*/{1024}, /*strides=*/{1}));
    subcomputation = module->AddEmbeddedComputation(builder.Build());
  }

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
  auto call_1 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
  auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
  auto call_2 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation));
  auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_2));
  auto call_3 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation));
  auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_3));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  auto count_broadcasts = [](const HloComputation* computation) {
    int64 bcast_count = 0;
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kBroadcast) {
        bcast_count++;
      }
    }
    return bcast_count;
  };

  // Before rematerialization there should be a single broadcast instruction in
  // the graph.
  EXPECT_EQ(count_broadcasts(entry_computation), 1);
  EXPECT_EQ(entry_computation->instruction_count(), 9);

  EXPECT_EQ(add_2->operand(0), bcast);
  EXPECT_EQ(add_3->operand(0), bcast);
  EXPECT_EQ(add_4->operand(0), bcast);

  // Pick a memory limit somewhere between 24KB (initial peak memory including
  // parameter and output) and 20KB (peak memory possible with
  // rematerialization).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/22 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // The broadcast should have been rematerialized 3 times.
  EXPECT_EQ(count_broadcasts(entry_computation), 4);
  EXPECT_EQ(entry_computation->instruction_count(), 12);

  // The operands of add_2, add_3, and add_4 should all be rematerialized
  // broadcasts.
  EXPECT_NE(add_2->operand(0), bcast);
  EXPECT_THAT(add_2->operand(0), op::Broadcast(param));
  EXPECT_NE(add_3->operand(0), bcast);
  EXPECT_THAT(add_3->operand(0), op::Broadcast(param));
  EXPECT_NE(add_4->operand(0), bcast);
  EXPECT_THAT(add_4->operand(0), op::Broadcast(param));
}

TEST_F(HloRematerializationTest, CopyNotRematerialized) {
  // Test that copies are not rematerialized.
  auto module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1024_shape_, "param"));

  auto copy = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kCopy, param));

  auto negate_a_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy));

  auto negate_a_2 = builder.AddInstruction(HloInstruction::CreateUnary(
      vec1024_shape_, HloOpcode::kNegate, negate_a_1));

  auto negate_b_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy));

  auto negate_b_2 = builder.AddInstruction(HloInstruction::CreateUnary(
      vec1024_shape_, HloOpcode::kNegate, negate_b_1));

  builder.AddInstruction(HloInstruction::CreateTuple({negate_a_2, negate_b_2}));

  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

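  // A 1KB limit is well below the size of even a single F32[1024] buffer
  // (4KB), so the pass rematerializes wherever it legally can; the copy should
  // still not be duplicated.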
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/1 * 1024, module.get()));

  auto count_copies = [](const HloComputation* computation) {
    int64 copy_count = 0;
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy) {
        copy_count++;
      }
    }
    return copy_count;
  };
  EXPECT_TRUE(changed);

  EXPECT_EQ(count_copies(entry_computation), 1);
}

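// Parameterized on whether the broadcast value below has an indirect use
// through a tuple (see IndirectUseNotRematerialized).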
class IndirectUseTest : public HloRematerializationTest,
                        public ::testing::WithParamInterface<bool> {};

TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
  // Test that a rematerializable instruction is not rematerialized if it has
  // an indirect use. The test is parameterized on whether the value has an
  // indirect use, and the instruction should be rematerialized iff the value
  // has no indirect use. Module:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] %bcast = broadcast(%param)
  //   F32[1024] %add_1 = add(%bcast, bcast)
  //   F32[1024] %call = call(Subcomputation, {%add_1})
  //   F32[1024] %add_2 = add(%bcast, call)
  //   {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2)
  //   F32[1024] %gte = GetTupleElement(%tuple, 0)
  //   F32[1024] %negate = negate(%gte)
  //
  // Subcomputation:
  //   F32[1024] %param = {...}
  //   F32[2048] %concat = concat({%param, %param})
  //   F32[1024] %slice = slice(%concat)
  //
  // The value %bcast is live across the call and rematerialization of %bcast
  // across that point would reduce peak memory use by 4KB. However, %bcast is
  // used indirectly by %negate, so rematerialization should not happen.
  //
  // This test is parameterized on whether the broadcast has an indirect use or
  // not. The indirect use is controlled by the index of the GetTupleElement
  // instruction. If the element is 0, then the %negate operand aliases %bcast
  // (i.e., %bcast is used indirectly by %negate); otherwise the %negate
  // operand aliases %add_2.
  const bool indirectly_used = GetParam();
  auto module = CreateNewVerifiedModule();

  HloComputation* subcomputation = nullptr;
  {
    auto builder = HloComputation::Builder(TestName() + ".subcomputation");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
        /*dimension=*/0));
    builder.AddInstruction(HloInstruction::CreateSlice(
        vec1024_shape_, concat, /*start_indices=*/{0},
        /*limit_indices=*/{1024}, /*strides=*/{1}));
    subcomputation = module->AddEmbeddedComputation(builder.Build());
  }

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
  auto call_1 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
  auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2}));
  auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      vec1024_shape_, tuple, indirectly_used ? 0 : 1));
  builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry_computation->instruction_count(), 8);

  // Pick a memory limit somewhere between 24KB (initial peak memory including
  // parameter and output) and 20KB (peak memory possible with
  // rematerialization).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/22 * 1024, module.get()));
  // Rematerialization should only occur if the rematerializable instruction
  // has no indirect uses.
  if (indirectly_used) {
    EXPECT_FALSE(changed);
    EXPECT_EQ(entry_computation->instruction_count(), 8);
  } else {
    EXPECT_TRUE(changed);
    EXPECT_EQ(entry_computation->instruction_count(), 9);
  }
}

INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest,
                         ::testing::Values(true, false));

}  // namespace

}  // namespace xla