/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {
namespace {

using ::testing::UnorderedElementsAre;

// DFS visitor that collects the instructions referenced by a computation
// without descending into nested computations, i.e., only the instructions
// reachable from the root through operands.
class InstructionListVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}

  Status DefaultAction(HloInstruction* hlo) override {
    // For each instruction, just push it on the list after walking the
    // operands.
    instructions_.push_back(hlo);
    VLOG(0) << "List instruction " << hlo->ToString();
    return Status::OK();
  }

  std::vector<const HloInstruction*> GetInstructions() { return instructions_; }

 private:
  // The instruction root of the computation.
  const HloInstruction* root_;

  // The full set of instructions found (may contain duplicates, e.g.,
  // kParameter).
  std::vector<const HloInstruction*> instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor);
};

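// Returns the instructions of the computation rooted at `root`, visiting
// operands before users (post order); nested computations are not visited.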
const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
  InstructionListVisitor main_list(root);
  TF_CHECK_OK(root->Accept(&main_list));
  return main_list.GetInstructions();
}

class BufferAssignmentTest : public HloTestBase {
 protected:
  ~BufferAssignmentTest() override {}

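  // Runs buffer assignment on `module` with a DependencyHloOrdering and the
  // given buffer alignment, allocating buffers for constants.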
  std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
                                                        int64 alignment = 1) {
    return BufferAssigner::Run(
               module, absl::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allow_input_output_aliasing=*/false,
               /*allocate_buffers_for_constants=*/true)
        .ConsumeValueOrDie();
  }

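  // Same as RunBufferAssignment, but does not allocate buffers for constants.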
  std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
      HloModule* module, int64 alignment = 1) {
    return BufferAssigner::Run(
               module, absl::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allow_input_output_aliasing=*/false,
               /*allocate_buffers_for_constants=*/false)
        .ConsumeValueOrDie();
  }

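  // Same as RunBufferAssignment, but installs a reuse checker that forbids
  // kAdd instructions from reusing any existing buffer, and does not allocate
  // buffers for constants.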
  std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersReuseForAdd(
      HloModule* module, int64 alignment = 1) {
    auto reuse_checker = [](const BufferAssignment& assignment,
                            const BufferAllocation& alloc,
                            const LogicalBuffer& buffer) {
      return (buffer.instruction()->opcode() != HloOpcode::kAdd);
    };
    return BufferAssigner::Run(
               module, absl::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allow_input_output_aliasing=*/false,
               /*allocate_buffers_for_constants=*/false,
               /*colorer=*/BufferLiveness::DefaultColorer(),
               /*reuse_checker=*/reuse_checker)
        .ConsumeValueOrDie();
  }

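  // Runs buffer assignment with a caller-supplied colorer; buffers with
  // different colors cannot share an allocation.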
  std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
      HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) {
    return BufferAssigner::Run(
               module, absl::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allow_input_output_aliasing=*/false,
               /*allocate_buffers_for_constants=*/true, std::move(colorer))
        .ConsumeValueOrDie();
  }

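  // Runs buffer assignment with a SequentialHloOrdering built from the given
  // instruction sequence for the entry computation.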
  std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
      HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
      int64 alignment = 1) {
    HloSchedule schedule(module);
    schedule.set_sequence(module->entry_computation(), instruction_sequence);
    return BufferAssigner::Run(
               module, absl::make_unique<SequentialHloOrdering>(schedule),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allow_input_output_aliasing=*/false,
               /*allocate_buffers_for_constants=*/true)
        .ConsumeValueOrDie();
  }

  // Builds an x+1.0 computation to use in a Map.
  std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
    auto builder = HloComputation::Builder(name);
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    auto value = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
    return builder.Build();
  }

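  // Builds an x+y computation to use as the reduction function of a Reduce.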
  std::unique_ptr<HloComputation> BuildReduceComputation(const string& name) {
    auto builder = HloComputation::Builder(name);
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    auto param2 =
        builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
    return builder.Build();
  }

  // Builds a simple compare-to-limit (x < 4) computation for a While.
  //
  // condition:
  //   const4[s32] -----------------------------------\
  //                                                    \
  //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
  //
  std::unique_ptr<HloComputation> BuildWhileConditionComputation(
      const string& name) {
    auto builder = HloComputation::Builder(name);
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
    auto index = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
    builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
                                      const4, ComparisonDirection::kLt));
    return builder.Build();
  }

  // Builds a simple body computation for a While.
  //
  // body:
  //   constv[f32[4]] --------------------------------------\
  //                                                          \
  //                           /--- get-tuple-elementv[1] --- addv ---\
  //   param[(s32,f32[4])] ---|                                        tuple
  //                           \--- get-tuple-elementc[0] --- addc ---/
  //                                                          /
  //   const1[s32] -----------------------------------------/
  //
  std::unique_ptr<HloComputation> BuildWhileBodyComputation(
      const string& name) {
    auto builder = HloComputation::Builder(name);
    auto const1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
    auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
    auto indexc = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
    auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
        indexc->shape(), HloOpcode::kAdd, indexc, const1));
    auto indexv = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
    auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
        constv->shape(), HloOpcode::kAdd, indexv, constv));
    builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
    return builder.Build();
  }

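  // Builds a computation that applies the given unary opcode to a scalar F32
  // parameter (used below as a Conditional branch computation).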
  std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
      HloOpcode opcode, const string& name) {
    auto builder = HloComputation::Builder(name);
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
    return builder.Build();
  }

  // Verifies that the given instruction hlo has a valid input buffer assigned,
  // i.e., the parameter number matches the op's, and returns it.
  const BufferAllocation& GetAssignedInputAllocation(
      const BufferAssignment& buffers, HloInstruction* hlo) {
    LOG(INFO) << "Checking input: " << hlo->ToString();
    const BufferAllocation& buffer =
        *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation();
    EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
    return buffer;
  }

  // Verifies that the given instruction hlo has a valid output buffer
  // assigned, and returns it.
  const BufferAllocation& GetAssignedOutputAllocation(
      const BufferAssignment& buffers, HloInstruction* hlo) {
    LOG(INFO) << "Checking output: " << hlo->ToString();
    const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
    return buffer;
  }

  // Returns the allocation for the given instruction at the given shape index.
  const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
                                        const HloInstruction* hlo,
                                        const ShapeIndex& index) {
    return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation();
  }
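  // Returns the allocation for the top-level buffer (shape index {}) of the
  // given instruction.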
  const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
                                                const HloInstruction* hlo) {
    return *buffers.GetUniqueTopLevelSlice(hlo)
                .ConsumeValueOrDie()
                .allocation();
  }

  // Verifies that all instructions in the given list, other than kConstant and
  // nested kParameter instructions, have buffers assigned, and returns the
  // total size of all allocated buffers.
  int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions,
                        const BufferAssignment& buffers) {
    // Verifies that each instruction either has a buffer or is exempt.
    for (const HloInstruction* hlo : instructions) {
      if (!buffers.HasTopLevelAllocation(hlo)) {
        // If `hlo` has no assigned buffer, it is either a constant or a nested
        // parameter.
        EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
                    HloOpcode::kParameter == hlo->opcode());
        continue;
      }
    }

    // Gets the total size of all buffers assigned.
    int64 total_size = 0;
    for (auto& allocation : buffers.Allocations()) {
      total_size += allocation.size();
    }
    return total_size;
  }

  // Shapes for use in the examples.
  Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
  Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
  Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
  Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
  Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
  Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
  Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
  Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
};

// Returns true if the buffers assigned to instructions in "a" are distinct
// from the buffers assigned to those in "b" (ie, intersection is empty).
static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
                            const std::vector<const HloInstruction*>& b,
                            const BufferAssignment& assignment) {
  absl::flat_hash_set<BufferAllocation::Slice> a_slices;
  for (const HloInstruction* instruction : a) {
    if (assignment.HasTopLevelAllocation(instruction)) {
      a_slices.insert(
          assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie());
    }
  }

  for (const HloInstruction* instruction : b) {
    if (assignment.HasTopLevelAllocation(instruction)) {
      if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction)
                                .ConsumeValueOrDie())) {
        return false;
      }
    }
  }
  return true;
}

// Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest, ScalarConstant) {
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    auto buffers = RunBufferAssignment(module.get());
    EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
  }

  {
    auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
  }
}

TEST_F(BufferAssignmentTest, BufferForConst) {
  // Addition of two vector constants: checks that the constant nodes are
  // assigned buffers only when buffers are allocated for constants, and that
  // their consumer always has an output buffer.
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    auto buffers = RunBufferAssignment(module.get());
    EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
    EXPECT_TRUE(buffers->HasTopLevelAllocation(const1));
    GetAssignedOutputAllocation(*buffers, add);
  }
  {
    auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
    EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
    GetAssignedOutputAllocation(*buffers, add);
  }
}

TEST_F(BufferAssignmentTest, HasAllocationAt) {
  // Create a tuple with non-const and const elements and check that
  // HasAllocationAt works correctly.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate, param0, constant}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
  // reports for the instruction directly.
  EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
            buffers->HasAllocationAt(tuple, /*index=*/{}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(negate),
            buffers->HasAllocationAt(tuple, /*index=*/{0}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(param0),
            buffers->HasAllocationAt(tuple, /*index=*/{1}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(constant),
            buffers->HasAllocationAt(tuple, /*index=*/{2}));
}

TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  // This computation copies a constant to output.
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto copy = builder.AddInstruction(
      HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  // The copy node now has an output buffer.
  GetAssignedOutputAllocation(*buffers, copy);
}

TEST_F(BufferAssignmentTest, Basic) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
  // If an input buffer is aliased with an output buffer, the input buffer can
  // be reused for other intermediate results.
  //
  // param0[100] ----- (neg1) -- (neg2)
  //    |                          |
  //    + -------- Aliased --------+

  auto builder = HloComputation::Builder(TestName());

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p0"));
  auto neg_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  auto neg_2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias(
      {}, 0, {}, HloInputOutputAliasConfig::kUserAlias));

  auto buffers = RunBufferAssignment(module.get());

  BufferAllocation param_buffer = GetAssignedInputAllocation(*buffers, param);
  BufferAllocation neg_1_buffer = GetAllocation(*buffers, neg_1, {});
  BufferAllocation neg_2_buffer = GetAllocation(*buffers, neg_2, {});

  // Everything uses one buffer.
  EXPECT_EQ(param_buffer.index(), neg_1_buffer.index());
  EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index());
}

TEST_F(BufferAssignmentTest, AddCannotReuse) {
  // Pass in a special rule to indicate that "add" cannot reuse any buffer.
  //
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get());

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node cannot reuse the mul node's buffer since we told buffer
  // assignment so.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  // The output of each op is colored with a different color, so we can not
  // share anything.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

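  // Colorer that gives every logical buffer its own color, so no two buffers
  // can share an allocation.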
  auto colorer = [](const BufferLiveness& buffer_liveness) {
    int color = 0;

    for (LogicalBuffer::Id id = 0;
         id < buffer_liveness.points_to_analysis().num_logical_buffers();
         id++) {
      auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
      buffer.set_color(LogicalBuffer::Color(color++));
    }
    return Status::OK();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can not reuse the mul node's buffer due to coloring.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  // The output of the mul and the add have the color 1, and the other buffers
  // have the color 0, which allows the mul and add to share buffers.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

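  // Colorer that assigns color 1 to buffers aliased by a kAdd or kMultiply
  // instruction and color 0 to all remaining buffers.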
  auto colorer = [](const BufferLiveness& buffer_liveness) {
    for (LogicalBuffer::Id id = 0;
         id < buffer_liveness.points_to_analysis().num_logical_buffers();
         id++) {
      auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
      const auto& aliases =
          buffer_liveness.points_to_analysis().GetBufferAliases(buffer);
      for (const auto& alias : aliases) {
        if (alias.instruction()->opcode() == HloOpcode::kAdd ||
            alias.instruction()->opcode() == HloOpcode::kMultiply) {
          buffer.set_color(LogicalBuffer::Color(1));
        }
      }
      if (!buffer.has_color()) {
        buffer.set_color(LogicalBuffer::Color(0));
      }
    }
    return Status::OK();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
  // This is similar to the Basic test, with the difference that (sub) is
  // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
  // (add)'s output.
  //
  // paramscalar -------\     /-----------\
  //                     \   /             \
  // param0[100] ------- (mul) -- (add) -- (sub)
  //                              /
  // param1[100] ----------------/
  //
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_index = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_index.index());
  EXPECT_NE(param0_buffer.index(), param1_index.index());

  // The mul node had a buffer allocated.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);

  // Now the add node can't reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // Log size information for inspection.
  const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
  int64 size0 = ValidateBuffers(level0, *buffers);
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() << " instructions; "
            << "total buffer size " << size0;
}

TEST_F(BufferAssignmentTest, TrivialMap) {
  // This tests a trivial x+1 map as the only operation.
  //
  // param0[100x10] ---> (map x+1)
  //
  // Builds the map function.
  auto module = CreateNewVerifiedModule();
  auto map_computation =
      module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
  auto inner_last = map_computation->root_instruction();

  // Creates the main kernel and verifies instruction counts.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  auto map = builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
  module->AddEntryComputation(builder.Build());

  const std::vector<const HloInstruction*> level0 = GetInstructions(map);
  EXPECT_EQ(2, level0.size()) << "Invalid main kernel size";
  const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last);
  EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";

  // Assigns buffers and fetches sizes.
  auto buffers = RunBufferAssignment(module.get());
  int64 size0 = ValidateBuffers(level0, *buffers);
  int64 size1 = ValidateBuffers(level1, *buffers);

  // Both algorithms assign the map's buffer before processing the embedded
  // computation, so we can verify that the buffers aren't shared between them
  // by checking:
  EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers))
      << "Reuse between main kernel and embedded mapping.";

  // An input buffer was assigned for the parameter.
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);

  // An output buffer was assigned for the map.
  BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map);
  EXPECT_NE(param0_buffer.index(), map_buffer.index());

  // The final computation node of the map is an add of an f32 param and a
  // constant.
  EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode());
  const BufferAllocation& inner_add_buffer =
      GetTopLevelAllocation(*buffers, inner_last);
  EXPECT_NE(inner_add_buffer.index(), map_buffer.index());

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + level1.size() << " instructions; "
            << "total buffer size " << size0 + size1;
}

TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  // Make sure that the input buffer of a reduce cannot be reused for its
  // output. (Reuse is not safe in the general case: the reduce reshapes its
  // operand, and an out-of-order reduction could overwrite an element before
  // it has been read.)
  //
  // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
  auto module = CreateNewVerifiedModule();
  auto reduce_computation =
      module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));

  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      /*shape=*/f32vec10_,
      /*operand=*/exp2,
      /*init_value=*/const0,
      /*dimensions_to_reduce=*/{0}, reduce_computation));
  auto exp3 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));

  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
  ValidateBuffers(instrs, *buffers);

  const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1);
  const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2);
  const BufferAllocation& reduce_buffer =
      GetTopLevelAllocation(*buffers, reduce);

  // The buffer of exp1 is trivially reusable for exp2 - this is just for
  // sanity checking.
  EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index());

  // The buffer of exp2 cannot be used for reduce, even though it's the only
  // operand.
  EXPECT_NE(exp2_buffer.index(), reduce_buffer.index());
}

TEST_F(BufferAssignmentTest, ExampleWhile) {
  // This tests a While loop example from the ir_semantics document.
  //
  // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
  // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
  //
  // const3[s32] -------\
  // const4[f32[4]] --- tuple --- while[condition, body]
  //
  // Builds the nested condition and body.
  auto module = CreateNewVerifiedModule();
  auto condition_computation =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
  auto body_computation =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));

  // Creates the main kernel and verifies instruction counts.
  auto builder = HloComputation::Builder(TestName());
  auto const3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
  auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v4_, condition_computation, body_computation, tuple));
  module->AddEntryComputation(builder.Build());

  const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
  EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
  const std::vector<const HloInstruction*> levelc =
      GetInstructions(condition_computation->root_instruction());
  EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
  const std::vector<const HloInstruction*> levelb =
      GetInstructions(body_computation->root_instruction());
  EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";

  // Assigns buffers and fetches sizes.
  auto buffers = RunBufferAssignment(module.get());
  int64 size0 = ValidateBuffers(level0, *buffers);
  int64 sizec = ValidateBuffers(levelc, *buffers);
  int64 sizeb = ValidateBuffers(levelb, *buffers);

  // BufferAssignment will assign a single allocation for the following
  // instructions: while, while.cond.param, while.body.param, while.body.result.
  EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
      << "Should be reuse between main kernel and embedded condition.";
  EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
      << "Should be reuse between embedded condition and body.";
  // Expect buffer reuse between main kernel and body computation.
  EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
      << "Should be reuse between main kernel and embedded body.";

  // The final computation node of the while body is a tuple of s32 and
  // f32[4] adds.
  HloInstruction* body_root = body_computation->root_instruction();
  EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());

  // Check that the buffer for each subshape of 'while_op' shares an allocation
  // with the corresponding buffer of the while body computation at the same
  // index.
  ShapeUtil::ForEachSubshape(
      while_op->shape(),
      [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
                                            const ShapeIndex& index) {
        auto while_op_allocation = GetAllocation(*buffers, while_op, index);
        auto body_root_allocation = GetAllocation(*buffers, body_root, index);
        EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index());
      });

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + levelc.size() + levelb.size()
            << " instructions; total buffer size " << size0 + sizec + sizeb;
}

TEST_F(BufferAssignmentTest, ExampleConditional) {
  auto module = CreateNewVerifiedModule();
  auto true_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
  auto false_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));

  auto builder = HloComputation::Builder(TestName());
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  auto const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
  auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
      r0f32_, pred, const1, true_computation, const2, false_computation));
  module->AddEntryComputation(builder.Build());

  const std::vector<const HloInstruction*> conditional_instrs =
      GetInstructions(conditional);
  const std::vector<const HloInstruction*> true_instrs =
      GetInstructions(true_computation->root_instruction());
  const std::vector<const HloInstruction*> false_instrs =
      GetInstructions(false_computation->root_instruction());
  EXPECT_EQ(4, conditional_instrs.size());
  EXPECT_EQ(2, true_instrs.size());
  EXPECT_EQ(2, false_instrs.size());

  auto buffers = RunBufferAssignment(module.get());
  ValidateBuffers(conditional_instrs, *buffers);
  ValidateBuffers(true_instrs, *buffers);
  ValidateBuffers(false_instrs, *buffers);

  EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers))
      << "Should be reuse between conditional and true computation.";
  EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers))
      << "Should be reuse between conditional and false computation.";
  EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers))
      << "Should be reuse between true and false computations.";

  const BufferAllocation& conditional_buffer =
      GetTopLevelAllocation(*buffers, conditional);
  const BufferAllocation& true_buffer =
      GetTopLevelAllocation(*buffers, true_computation->root_instruction());
  const BufferAllocation& false_buffer =
      GetTopLevelAllocation(*buffers, false_computation->root_instruction());
  EXPECT_EQ(conditional_buffer.size(), true_buffer.size());
  EXPECT_EQ(conditional_buffer.size(), false_buffer.size());
}

TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
  // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // tanh, exp2, and neg can reuse exp1's buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1));
  auto& buffer_for_exp1 = GetTopLevelAllocation(*assignment, exp1);
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh));
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2));
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg));
}

TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
  // This computation is a chain of operations which decreases in buffer size
  // (via slice) then increases in size (via broadcast):
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
  // This computation is identical to that in ReuseNonOperandBuffer, but the
  // negate value is live until the end of the computation (due to it being an
  // operand of the output tuple), preventing reuse.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
  //                  \-----------------------------------/
  //
  // The negate should not share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The instructions should not share buffers.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
  EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
            GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
  // This computation is identical to that in ReuseNonOperandBuffer, but the
  // negate value is placed into a tuple which lives to the end of the
  // computation. This extends the live range of negate's buffer preventing
  // reuse due to buffer aliasing.
  //
  // param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple)
  //                             \-----------------------------------/
  //
  // The negate should not share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  auto tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The instructions should not share buffers.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
  EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
            GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
  // This computation is very similar to ReuseNonOperandBuffer except the
  // broadcast has a smaller output than the negate. This should block reuse of
  // negate's buffer by broadcast because the output buffer(s) of a computation
  // should be exactly sized for the value.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // Neither negate nor slice may share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 40 elements.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast output buffer cannot be shared.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
  // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
  // output is exactly the same size as the negate (rather than being
  // smaller). This enables reuse of negate's buffer by the broadcast because
  // the output buffer will be sized exactly to its value.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 100 elements.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
  // This computation is very similar to ReuseNonOperandBuffer except the
  // broadcast has a smaller output than the negate, and the broadcast is
  // contained in the computation output as a tuple element. This should block
  // reuse of the negate's buffer by the broadcast because the output buffer(s)
  // of a computation should be exactly sized for the value. This includes
  // those buffers aliased in the output (eg, contained as tuple elements).
  //
  // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
  //
  // Neither negate nor slice may share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 40 elements.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
  builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast output buffer cannot be shared.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
}
1193
TEST_F(BufferAssignmentTest,EmbeddedComputationBuffers)1194 TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
1195 // Verify that buffers for embedded computations are properly marked as
1196 // thread-local and that embedded parameters are not marked as
1197 // is_entry_computation_parameter.
1198 auto module = CreateNewVerifiedModule();
1199 auto vec_shape = ShapeUtil::MakeShape(F32, {42});
1200 auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1201
1202 // Create a scalar computation to use in a map.
1203 auto map_builder = HloComputation::Builder(TestName() + "_map");
1204 auto map_param = map_builder.AddInstruction(
1205 HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
1206 auto map_root = map_builder.AddInstruction(
1207 HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
1208 auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
1209
1210 // Create a vector computation to use in a kCall.
1211 auto call_builder = HloComputation::Builder(TestName() + "_call");
1212 auto call_param = call_builder.AddInstruction(
1213 HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
1214 auto call_root = call_builder.AddInstruction(
1215 HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
1216 auto call_computation = module->AddEmbeddedComputation(call_builder.Build());
1217
1218 // Create entry computation which kCalls call_computation and then calls map
1219 // with map_computation on the result.
1220 auto builder = HloComputation::Builder(TestName());
1221 auto param = builder.AddInstruction(
1222 HloInstruction::CreateParameter(0, vec_shape, "param"));
1223 auto call = builder.AddInstruction(
1224 HloInstruction::CreateCall(vec_shape, {param}, call_computation));
1225 auto map = builder.AddInstruction(
1226 HloInstruction::CreateMap(vec_shape, {call}, map_computation));
1227 module->AddEntryComputation(builder.Build());
1228
1229 auto assignment = RunBufferAssignment(module.get());
1230
1231 // Allocations for the map computation should be thread-local and not
1232 // live-out.
1233 auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
1234 EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
1235 EXPECT_FALSE(map_param_alloc.maybe_live_out());
1236 EXPECT_TRUE(map_param_alloc.is_thread_local());
1237
1238 auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
1239 EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
1240 EXPECT_FALSE(map_root_alloc.maybe_live_out());
1241 EXPECT_TRUE(map_root_alloc.is_thread_local());
1242
1243 // Allocations for the call computation should not be thread-local.
1244 auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
1245 EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter());
1246 EXPECT_FALSE(call_param_alloc.maybe_live_out());
1247 EXPECT_FALSE(call_param_alloc.is_thread_local());
1248
1249 auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
1250 EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
1251 EXPECT_FALSE(call_root_alloc.is_thread_local());
1252
1253   // Allocations in the entry computation may be marked maybe_live_out and
1254   // is_entry_computation_parameter.
1255 auto& param_alloc = GetTopLevelAllocation(*assignment, param);
1256 EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
1257 EXPECT_FALSE(param_alloc.maybe_live_out());
1258 EXPECT_FALSE(param_alloc.is_thread_local());
1259
1260 auto& map_alloc = GetTopLevelAllocation(*assignment, map);
1261 EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
1262 EXPECT_TRUE(map_alloc.maybe_live_out());
1263 EXPECT_FALSE(map_alloc.is_thread_local());
1264 }
1265
TEST_F(BufferAssignmentTest,TupleParameterAsOutput)1266 TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
1267 // Test a computation that returns a tuple parameter.
1268 auto builder = HloComputation::Builder(TestName());
1269 auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
1270 0,
1271 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1272 ShapeUtil::MakeShape(F32, {}),
1273 ShapeUtil::MakeShape(S32, {42})}),
1274 "param0"));
1275
1276 auto module = CreateNewVerifiedModule();
1277 module->AddEntryComputation(builder.Build());
1278 auto assignment = RunBufferAssignment(module.get());
1279
1280   // There should be four allocations: one for the vector of pointers (the
1281   // top-level tuple buffer), and one for each of the three tuple elements.
1282 EXPECT_EQ(4, assignment->Allocations().size());
1283
1284 // Verify each buffer allocation is marked as an entry computation parameter
1285 // and is liveout.
1286 ShapeUtil::ForEachSubshape(
1287 tuple_param->shape(),
1288 [this, &assignment, tuple_param](const Shape& /*subshape*/,
1289 const ShapeIndex& index) {
1290 auto allocation = GetAllocation(*assignment, tuple_param, index);
1291 EXPECT_TRUE(allocation.is_entry_computation_parameter());
1292 EXPECT_EQ(0, allocation.parameter_number());
1293 EXPECT_TRUE(allocation.maybe_live_out());
1294 });
1295 }
1296
TEST_F(BufferAssignmentTest,ElementOfNestedTupleParameterAsOutput)1297 TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
1298   // Test a computation which returns a GetTupleElement of a nested tuple
1299 // parameter.
1300 auto builder = HloComputation::Builder(TestName());
1301 auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
1302 0,
1303 ShapeUtil::MakeTupleShape(
1304 {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1305 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
1306 ShapeUtil::MakeShape(S32, {101})})}),
1307 "param0"));
1308 auto tuple_element =
1309 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1310 ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));
1311
1312 auto module = CreateNewVerifiedModule();
1313 module->AddEntryComputation(builder.Build());
1314 auto assignment = RunBufferAssignment(module.get());
1315
1316 // Only some of the elements of the input param are liveout.
1317 EXPECT_FALSE(
1318 GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
1319 // Tuple element at index={1} is live out because GetTupleElement({1})
1320 // forwards a pointer to this allocation (instead of defining its own buffer).
1321 EXPECT_TRUE(
1322 GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
1323 EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
1324 .maybe_live_out());
1325 EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
1326 .maybe_live_out());
1327
1328 // The GetTupleElement output is liveout.
1329 EXPECT_TRUE(
1330 GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());
1331
1332   // Verify that the allocations of the GetTupleElement's elements match the
1333   // corresponding tuple parameter allocations because they alias.
1334 EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
1335 GetAllocation(*assignment, tuple_element, /*index=*/{0}));
1336 EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
1337 GetAllocation(*assignment, tuple_element, /*index=*/{1}));
1338
1339 // GetTupleElement forwards a pointer to its underlying buffer, so verify
1340   // that it has the same allocation as the corresponding parameter element.
1341 EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
1342 GetTopLevelAllocation(*assignment, tuple_element));
1343 }
1344
1345 // TODO(b/32248867): Enable when buffer assignment gives allocations to
1346 // constants.
TEST_F(BufferAssignmentTest,TupleConstantAsOutput)1347 TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
1348 // Test that a tuple constant which is forwarded to the computation output
1349 // is properly handled.
1350 auto builder = HloComputation::Builder(TestName());
1351 Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
1352 LiteralUtil::CreateR0<int64>(1)};
1353 builder.AddInstruction(HloInstruction::CreateConstant(
1354 LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
1355
1356 auto module = CreateNewVerifiedModule();
1357 module->AddEntryComputation(builder.Build());
1358 auto assignment = RunBufferAssignment(module.get());
1359
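  // Three allocations: one for the top-level tuple and one for each of the two
  // scalar constant elements.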
1360 EXPECT_EQ(3, assignment->Allocations().size());
1361 }
1362
TEST_F(BufferAssignmentTest,TupleCustomCallAsOutput)1363 TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
1364 // Test a computation which returns a tuple custom call value.
1365 auto builder = HloComputation::Builder(TestName());
1366 auto custom_call = builder.AddInstruction(HloInstruction::CreateCustomCall(
1367 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1368 ShapeUtil::MakeShape(S32, {101})}),
1369 /*operands=*/{}, /*custom_call_target=*/"foo_function"));
1370 auto module = CreateNewVerifiedModule();
1371 module->AddEntryComputation(builder.Build());
1372 auto assignment = RunBufferAssignment(module.get());
1373
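  // Three allocations: one for the top-level tuple and one for each of the two
  // tuple elements.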
1374 EXPECT_EQ(3, assignment->Allocations().size());
1375 EXPECT_TRUE(
1376 GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
1377 EXPECT_TRUE(
1378 GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
1379 EXPECT_TRUE(
1380 GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
1381 }
1382
TEST_F(BufferAssignmentTest,TupleCallAsOutput)1383 TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
1384 // Test a computation which returns a tuple call value.
1385 auto module = CreateNewVerifiedModule();
1386 auto elem_shape = f32vec4_;
1387 auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1388
1389 auto sub_builder = HloComputation::Builder(TestName() + "_sub");
1390 auto sub_param = sub_builder.AddInstruction(
1391 HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
1392 auto sub_tuple =
1393 sub_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
1394 auto sub_computation = module->AddEmbeddedComputation(sub_builder.Build());
1395
1396 auto builder = HloComputation::Builder(TestName());
1397 auto param = builder.AddInstruction(
1398 HloInstruction::CreateParameter(0, elem_shape, "param"));
1399 auto call = builder.AddInstruction(
1400 HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
1401 module->AddEntryComputation(builder.Build());
1402
1403 auto assignment = RunBufferAssignment(module.get());
1404
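  // Two allocations: one for the entry parameter (aliased by the
  // sub-computation parameter and the call result element) and one for the
  // result tuple itself.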
1405 EXPECT_EQ(2, assignment->Allocations().size());
1406 // Buffers for call are colocated with the sub-computation.
1407 EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
1408 GetAllocation(*assignment, sub_tuple, /*index=*/{}));
1409 EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
1410 GetAllocation(*assignment, sub_param, /*index=*/{}));
1411
1412 // The parameter isn't aliased with the result tuple, but it is aliased with
1413 // the call operand.
1414 EXPECT_NE(GetTopLevelAllocation(*assignment, param),
1415 GetTopLevelAllocation(*assignment, sub_tuple));
1416 EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
1417 GetTopLevelAllocation(*assignment, sub_param));
1418 }
1419
TEST_F(BufferAssignmentTest,TupleChainedCallAsOutput)1420 TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
1421 // Test a chain of calls with tuple output. The chain looks like:
1422 // A: call(B, tuple(param))
1423 // B: call(C, param)
1424 // C: call(D, param)
1425 // D: param
1426 auto module = CreateNewVerifiedModule();
1427 auto elem_shape = f32vec4_;
1428 auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1429
1430 auto d_builder = HloComputation::Builder(TestName() + "_d");
1431 auto d_param = d_builder.AddInstruction(
1432 HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
1433 auto d_computation = d_builder.Build();
1434
1435 auto c_builder = HloComputation::Builder(TestName() + "_c");
1436 auto c_param = c_builder.AddInstruction(
1437 HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
1438 auto c_call = c_builder.AddInstruction(
1439 HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
1440 auto c_computation = c_builder.Build();
1441
1442 auto b_builder = HloComputation::Builder(TestName() + "_b");
1443 auto b_param = b_builder.AddInstruction(
1444 HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
1445 auto b_call = b_builder.AddInstruction(
1446 HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
1447 auto b_computation = b_builder.Build();
1448
1449 auto a_builder = HloComputation::Builder(TestName());
1450 auto a_param = a_builder.AddInstruction(
1451 HloInstruction::CreateParameter(0, elem_shape, "param"));
1452 auto a_tuple =
1453 a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
1454 auto a_call = a_builder.AddInstruction(
1455 HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
1456 auto a_computation = a_builder.Build();
1457
1458 // Add the computations in an order that doesn't match the dependency
1459 // post-order, to shake out more possible bugs.
1460 module->AddEmbeddedComputation(std::move(d_computation));
1461 module->AddEmbeddedComputation(std::move(c_computation));
1462 module->AddEntryComputation(std::move(a_computation));
1463 module->AddEmbeddedComputation(std::move(b_computation));
1464
1465 auto assignment = RunBufferAssignment(module.get());
1466
1467 // Buffers for call are colocated with the sub-computations.
1468 EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
1469 GetAllocation(*assignment, b_call, /*index=*/{}));
1470 EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
1471 GetAllocation(*assignment, c_call, /*index=*/{}));
1472 EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
1473 GetAllocation(*assignment, d_param, /*index=*/{}));
1474 EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
1475 GetAllocation(*assignment, b_call, /*index=*/{0}));
1476 EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
1477 GetAllocation(*assignment, c_call, /*index=*/{0}));
1478 EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
1479 GetAllocation(*assignment, d_param, /*index=*/{0}));
1480
1481 EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
1482 EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
1483 EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));
1484
1485 EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}),
1486 GetAllocation(*assignment, c_param, /*index=*/{0}));
1487 EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}),
1488 GetAllocation(*assignment, d_param, /*index=*/{0}));
1489 }
1490
TEST_F(BufferAssignmentTest,BitcastAsOutput)1491 TEST_F(BufferAssignmentTest, BitcastAsOutput) {
1492 // Test a computation which returns a bitcast value.
1493 auto builder = HloComputation::Builder(TestName());
1494 auto param = builder.AddInstruction(HloInstruction::CreateParameter(
1495 0, ShapeUtil::MakeShape(F32, {42}), "param"));
1496 auto bitcast = builder.AddInstruction(
1497 HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param));
1498
1499 auto module = CreateNewVerifiedModule();
1500 module->AddEntryComputation(builder.Build());
1501 auto assignment = RunBufferAssignment(module.get());
1502
1503 // Bitcast should get the same allocation as the param.
1504 EXPECT_EQ(1, assignment->Allocations().size());
1505 EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
1506 GetTopLevelAllocation(*assignment, bitcast));
1507 }
1508
TEST_F(BufferAssignmentTest,AmbiguousBufferAsOutput)1509 TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
1510 // Test a computation with an output that has an ambiguous points-to set.
1511 // This is constructed using a select among tuple shapes.
1512 auto builder = HloComputation::Builder(TestName());
1513 auto tuple_shape =
1514 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});
1515
1516 auto tuple_param0 = builder.AddInstruction(
1517 HloInstruction::CreateParameter(0, tuple_shape, "param0"));
1518 auto tuple_param1 = builder.AddInstruction(
1519 HloInstruction::CreateParameter(1, tuple_shape, "param1"));
1520 auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
1521 2, ShapeUtil::MakeShape(PRED, {}), "param1"));
1522 auto select = builder.AddInstruction(
1523 HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect,
1524 pred_param, tuple_param0, tuple_param1));
1525
1526 auto module = CreateNewVerifiedModule();
1527 module->AddEntryComputation(builder.Build());
1528 auto assignment = RunBufferAssignment(module.get());
1529
1530 // Select shallow copies one of its operands so it defines its own top-level
1531 // buffer and receives its own allocation.
1532 auto select_alloc = GetTopLevelAllocation(*assignment, select);
1533 EXPECT_EQ(1, select_alloc.assigned_buffers().size());
1534 EXPECT_EQ(select,
1535 select_alloc.assigned_buffers().begin()->first->instruction());
1536
1537   // The buffer for the tuple element of the select is forwarded from one of
1538   // its operands, which cannot be determined statically. Therefore its slices
1539 // should include the slices of both of the elements in the parameters.
1540 auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
1541 EXPECT_EQ(2, element_slices.size());
1542 EXPECT_THAT(element_slices,
1543 UnorderedElementsAre(
1544 assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
1545 .ConsumeValueOrDie(),
1546 assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
1547 .ConsumeValueOrDie()));
1548 }
1549
1550 // TODO(b/34669761): Remove this test when buffers are allowed to share
1551 // allocations.
TEST_F(BufferAssignmentTest,TupleBufferNotReused)1552 TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
1553   // Test that a copy of an element extracted from a tuple does not reuse the
  // tuple's buffer.
1554 auto builder = HloComputation::Builder(TestName());
1555 auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1556 auto param = builder.AddInstruction(
1557 HloInstruction::CreateParameter(0, scalar_shape, "param0"));
1558 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({param}));
1559 auto tuple_element = builder.AddInstruction(
1560 HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
1561 auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
1562 scalar_shape, HloOpcode::kCopy, tuple_element));
1563
1564 auto module = CreateNewVerifiedModule();
1565 module->AddEntryComputation(builder.Build());
1566 auto assignment = RunBufferAssignment(module.get());
1567
1568 // There should be no buffer reuse. The copy should not reuse the tuple
1569 // buffer.
1570 EXPECT_EQ(3, assignment->Allocations().size());
1571 EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
1572 GetTopLevelAllocation(*assignment, copy));
1573 }
1574
TEST_F(BufferAssignmentTest,OneTempAllocation)1575 TEST_F(BufferAssignmentTest, OneTempAllocation) {
1576 // Test a computation that requires multiple temp buffers, and ensure they
1577 // are combined into a single allocation.
1578 auto builder = HloComputation::Builder(TestName());
1579 Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
1580 Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
1581 Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
1582 Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
1583 Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});
1584
1585 // There should be separate temp buffers for dot_ab and dot_bc.
1586 auto param_a = builder.AddInstruction(
1587 HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
1588 auto param_b = builder.AddInstruction(
1589 HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
1590 auto param_c = builder.AddInstruction(
1591 HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
1592 DotDimensionNumbers dot_dnums;
1593 dot_dnums.add_lhs_contracting_dimensions(1);
1594 dot_dnums.add_rhs_contracting_dimensions(0);
1595 PrecisionConfig precision_config;
1596 precision_config.mutable_operand_precision()->Resize(
1597 2, PrecisionConfig::DEFAULT);
1598 auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
1599 shape_2x4, param_a, param_b, dot_dnums, precision_config));
1600 auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
1601 shape_3x4, param_b, param_c, dot_dnums, precision_config));
1602 builder.AddInstruction(
1603 HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
1604
1605 // Run buffer assignment with alignment=1.
1606 auto module = CreateNewVerifiedModule();
1607 module->AddEntryComputation(builder.Build());
1608 auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
1609
1610 // There are 5 allocations: 3 parameters, 1 output, and 1 temp.
1611 EXPECT_EQ(5, assignment->Allocations().size());
1612
1613   // Ensure the temp buffers for dot_ab and dot_bc share a single allocation,
1614   // and that each occupies a different slice of that allocation.
1615 BufferAllocation::Slice slice_ab =
1616 assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
1617 BufferAllocation::Slice slice_bc =
1618 assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
1619 EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1620 EXPECT_NE(slice_ab, slice_bc);
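  // dot_ab is f32[2,4] (32 bytes) and dot_bc is f32[3,4] (48 bytes); with
  // alignment 1 they pack into a single 80-byte temp allocation.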
1621 EXPECT_EQ(32, slice_ab.size());
1622 EXPECT_EQ(48, slice_bc.size());
1623 EXPECT_EQ(80, slice_ab.allocation()->size());
1624 EXPECT_EQ(80, slice_bc.allocation()->size());
1625
1626 // Re-run buffer assignment with alignment=64.
1627 assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
1628 EXPECT_EQ(5, assignment->Allocations().size());
1629 slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
1630 slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
1631 EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1632 EXPECT_NE(slice_ab, slice_bc);
1633 EXPECT_EQ(32, slice_ab.size());
1634 EXPECT_EQ(48, slice_bc.size());
1635 // Ensure the offsets and allocation size account for the alignment, without
1636 // assuming which buffer gets assigned first.
1637 if (slice_ab.offset() == 0) {
1638 EXPECT_EQ(64, slice_bc.offset());
1639 EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
1640 EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
1641 } else {
1642 EXPECT_EQ(64, slice_ab.offset());
1643 EXPECT_EQ(0, slice_bc.offset());
1644 EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
1645 EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
1646 }
1647 }
1648
TEST_F(BufferAssignmentTest,TrivialPeakBuffers)1649 TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
1650 // paramscalar ------- (mul) -- (add) -- (sub)
1651 // / / /
1652 // param0[100] -------/ / /
1653 // / /
1654 // param1[100] --------------/--------/
1655 auto builder = HloComputation::Builder(TestName());
1656 auto paramscalar =
1657 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
1658 auto broadcast = builder.AddInstruction(
1659 HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
1660 auto param0 = builder.AddInstruction(
1661 HloInstruction::CreateParameter(1, f32vec100_, "p1"));
1662 auto param1 = builder.AddInstruction(
1663 HloInstruction::CreateParameter(2, f32vec100_, "p2"));
1664 auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
1665 f32vec100_, HloOpcode::kMultiply, broadcast, param0));
1666 auto add = builder.AddInstruction(
1667 HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
1668 builder.AddInstruction(HloInstruction::CreateBinary(
1669 f32vec100_, HloOpcode::kSubtract, add, param1));
1670 auto module = CreateNewVerifiedModule();
1671 module->AddEntryComputation(builder.Build());
1672
1673 auto buffers = RunBufferAssignment(module.get());
1674
1675 const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
1676 const std::vector<const LogicalBuffer*>& peak_buffers =
1677 mul_buffer.PeakMemoryLogicalBuffers();
1678 ASSERT_EQ(peak_buffers.size(), 1);
1679 EXPECT_EQ(peak_buffers[0]->instruction(), broadcast);
1680 }
1681
TEST_F(BufferAssignmentTest,PeakBuffers)1682 TEST_F(BufferAssignmentTest, PeakBuffers) {
1683 // Compute the peak liveness buffers of the following sequence:
1684 //
1685 // %param = ...
1686 // %log = log(%param)
1687 // %rev = reverse(%log)
1688 // %neg = neg(%param)
1689 // %concat = concat(%rev, %neg)
1690 // ROOT %root = slice(concat)
1691 //
1692 // In the temporary block, the set of live buffers at peak memory use should
1693 // be {%rev, %neg, %concat}. This occurs right at the concat itself.
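  // (%log is dead by then since its only use is %rev, and %param is an input
  // rather than a temporary, so only %rev, %neg, and %concat are live in the
  // temp allocation.)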
1694 auto builder = HloComputation::Builder(TestName());
1695 auto param = builder.AddInstruction(
1696 HloInstruction::CreateParameter(0, f32vec100_, "p"));
1697 auto log = builder.AddInstruction(
1698 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
1699 auto rev = builder.AddInstruction(
1700 HloInstruction::CreateReverse(f32vec100_, log, {0}));
1701 auto neg = builder.AddInstruction(
1702 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
1703 const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
1704 auto concat = builder.AddInstruction(
1705 HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
1706 // Make the root tiny so no interior nodes can share its buffer.
1707   auto root = builder.AddInstruction(HloInstruction::CreateSlice(
1709       ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1}));
1710
1711 auto module = CreateNewVerifiedModule();
1712 module->AddEntryComputation(builder.Build());
1713
1714 auto buffers = RunBufferAssignmentWithInstructionSequence(
1715 module.get(), {param, log, rev, neg, concat, root});
1716
1717 // The temporary buffer should hold the 4 interior instructions.
1718 const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat);
1719 EXPECT_FALSE(buffer.IsInputOrOutput());
1720 EXPECT_TRUE(buffer.IsPreallocatedTempBuffer());
1721 ASSERT_EQ(buffer.assigned_buffers().size(), 4);
1722
1723 const std::vector<const LogicalBuffer*>& peak_buffers =
1724 buffer.PeakMemoryLogicalBuffers();
1725
1726 // The peak live set should be concat and its inputs.
1727 ASSERT_EQ(peak_buffers.size(), 3);
1728 std::vector<const HloInstruction*> peak_instructions;
1729 for (const LogicalBuffer* logical_buffer : peak_buffers) {
1730 peak_instructions.push_back(logical_buffer->instruction());
1731 }
1732 EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
1733 }
1734
TEST_F(BufferAssignmentTest,PeakBuffersWhile)1735 TEST_F(BufferAssignmentTest, PeakBuffersWhile) {
1736 auto module = CreateNewVerifiedModule();
1737 const Shape shape = ShapeUtil::MakeShape(F32, {123, 123});
1738 HloComputation* condition;
1739 {
1740 auto b = HloComputation::Builder(TestName() + ".cond");
1741 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
1742 b.AddInstruction(
1743 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1744 condition = module->AddEmbeddedComputation(b.Build());
1745 }
1746 HloComputation* body;
1747 {
1748 auto b = HloComputation::Builder(TestName() + ".body");
1749 auto param =
1750 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
1751 b.AddInstruction(
1752 HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
1753 body = module->AddEmbeddedComputation(b.Build());
1754 }
1755 auto builder = HloComputation::Builder(TestName());
1756 auto param =
1757 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
1758 auto copy = builder.AddInstruction(
1759 HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param));
1760 auto while_op = builder.AddInstruction(
1761 HloInstruction::CreateWhile(shape, condition, body, copy));
1762 // This broadcast should get a temporary allocation which is merged with the
1763 // allocation for the while. Peak buffers should include the while and the
1764 // broadcast.
1765 auto bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1766 ShapeUtil::MakeShape(F32, {123, 123, 123}), while_op, {0, 1}));
1767 builder.AddInstruction(HloInstruction::CreateReverse(
1768 ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0}));
1769 module->AddEntryComputation(builder.Build());
1770
1771 auto buffers = RunBufferAssignment(module.get());
1772 const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast);
1773 const std::vector<const LogicalBuffer*>& peak_buffers =
1774 buffer.PeakMemoryLogicalBuffers();
1775 ASSERT_EQ(peak_buffers.size(), 2);
1776
1777   // The peak buffers should include the broadcast and one of the colocated
1778   // buffers of the while loop (the init copy, body parameter, condition
1779   // parameter, body root, or the while itself).
1780 const LogicalBuffer* bcast_buffer;
1781 const LogicalBuffer* nonbcast_buffer;
1782 if (peak_buffers[0]->instruction() == bcast) {
1783 bcast_buffer = peak_buffers[0];
1784 nonbcast_buffer = peak_buffers[1];
1785 } else {
1786 bcast_buffer = peak_buffers[1];
1787 nonbcast_buffer = peak_buffers[0];
1788 }
1789 EXPECT_EQ(bcast_buffer->instruction(), bcast);
1790 EXPECT_TRUE(
1791 nonbcast_buffer->instruction() == copy ||
1792 nonbcast_buffer->instruction() == while_op ||
1793 nonbcast_buffer->instruction() == body->parameter_instruction(0) ||
1794 nonbcast_buffer->instruction() == body->root_instruction() ||
1795 nonbcast_buffer->instruction() == condition->parameter_instruction(0));
1796 }
1797
TEST_F(BufferAssignmentTest,ConstantBuffersAreNotReused)1798 TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
1799 const char* hlo_text = R"(
1800 HloModule Module
1801
1802 True {
1803 ROOT x.0.1 = f32[] parameter(0)
1804 }
1805
1806 False {
1807 x.0.0 = f32[] parameter(0)
1808 ROOT copy.1 = f32[] copy(x.0.0)
1809 }
1810
1811 ENTRY main {
1812 pred.1.0 = pred[] parameter(0)
1813 constant.1.1 = f32[] constant(56)
1814 copy.2 = f32[] copy(constant.1.1)
1815 constant.1.2 = f32[] constant(12)
1816 ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2),
1817 true_computation=True, false_computation=False
1818 }
1819 )";
1820
1821 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
1822 HloInstruction* constant_1 =
1823 m->entry_computation()->GetInstructionWithName("constant.1.1");
1824 HloInstruction* constant_2 =
1825 m->entry_computation()->GetInstructionWithName("constant.1.2");
1826
1827 auto buffers = RunBufferAssignment(m.get());
1828
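  // Neither constant's allocation may also hold the value produced by the copy
  // or by the conditional.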
1829 {
1830 const BufferAllocation& allocation_for_const_1 =
1831 GetTopLevelAllocation(*buffers, constant_1);
1832 EXPECT_TRUE(allocation_for_const_1.is_constant());
1833 for (const auto& buffer_offset_pair :
1834 allocation_for_const_1.assigned_buffers()) {
1835 EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
1836 HloOpcode::kCopy);
1837 EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
1838 HloOpcode::kConditional);
1839 }
1840 }
1841
1842 {
1843 const BufferAllocation& allocation_for_const_2 =
1844 GetTopLevelAllocation(*buffers, constant_2);
1845 EXPECT_TRUE(allocation_for_const_2.is_constant());
1846 for (const auto& buffer_offset_pair :
1847 allocation_for_const_2.assigned_buffers()) {
1848 EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
1849 HloOpcode::kCopy);
1850 EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
1851 HloOpcode::kConditional);
1852 }
1853 }
1854 }
1855
1856 class WhileBufferAssignmentTest : public HloTestBase {
1857 protected:
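  // Builds a loop condition that compares two constants (0 < 10) and ignores
  // its loop_state parameter, so it always evaluates to true.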
BuildWhileConditionComputation(const string & name)1858 std::unique_ptr<HloComputation> BuildWhileConditionComputation(
1859 const string& name) {
1860 auto builder = HloComputation::Builder(name);
1861 builder.AddInstruction(
1862 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
1863 auto zero = builder.AddInstruction(
1864 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
1865 auto ten = builder.AddInstruction(
1866 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
1867 builder.AddInstruction(HloInstruction::CreateCompare(
1868 ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
1869 return builder.Build();
1870 }
1871
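  // Builds a loop body that multiplies the input and weights elements of the
  // loop state and returns (input, weights, input * weights).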
BuildWhileBodyComputation(const string & name)1872 std::unique_ptr<HloComputation> BuildWhileBodyComputation(
1873 const string& name) {
1874 auto builder = HloComputation::Builder(name);
1875 auto loop_state = builder.AddInstruction(
1876 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
1877 auto input = builder.AddInstruction(
1878 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
1879 auto weights = builder.AddInstruction(
1880 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
1881 auto output = builder.AddInstruction(HloInstruction::CreateBinary(
1882 data_shape_, HloOpcode::kMultiply, input, weights));
1883 builder.AddInstruction(
1884 HloInstruction::CreateTuple({input, weights, output}));
1885 return builder.Build();
1886 }
1887
RunBufferAssignment(HloModule * module,int64 alignment=1)1888 std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
1889 int64 alignment = 1) {
1890 HloSchedule schedule =
1891 ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
1892 return BufferAssigner::Run(
1893 module, absl::make_unique<SequentialHloOrdering>(schedule),
1894 ByteSizeOf,
1895 [alignment](LogicalBuffer::Color) { return alignment; },
1896 /*allow_input_output_aliasing=*/false,
1897 /*allocate_buffers_for_constants=*/true)
1898 .ConsumeValueOrDie();
1899 }
1900
ByteSizeOf(const BufferValue & buffer)1901 static int64 ByteSizeOf(const BufferValue& buffer) {
1902 return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
1903 }
1904
1905 Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
1906 Shape loop_state_shape_ =
1907 ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
1908 };
1909
RunCopyInsertion(HloModule * module)1910 static void RunCopyInsertion(HloModule* module) {
1911 CopyInsertion copy_insertion;
1912 EXPECT_IS_OK(copy_insertion.Run(module).status());
1913 }
1914
TEST_F(WhileBufferAssignmentTest,TwoForwardWhileLoops)1915 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
1916 auto module = CreateNewVerifiedModule();
1917 auto builder = HloComputation::Builder("entry");
1918
1919 auto input0 = builder.AddInstruction(
1920 HloInstruction::CreateParameter(0, data_shape_, "input0"));
1921 auto weights0 = builder.AddInstruction(
1922 HloInstruction::CreateParameter(1, data_shape_, "weights0"));
1923 auto weights1 = builder.AddInstruction(
1924 HloInstruction::CreateParameter(2, data_shape_, "weights1"));
1925
1926 auto zero = builder.AddInstruction(
1927 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
1928 auto output0 = builder.AddInstruction(
1929 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
1930 auto output1 = builder.AddInstruction(
1931 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
1932
1933 auto cond0 =
1934 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
1935 auto body0 =
1936 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
1937
1938 auto tuple0 = builder.AddInstruction(
1939 HloInstruction::CreateTuple({input0, weights0, output0}));
1940 auto while0 = builder.AddInstruction(
1941 HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
1942
1943 auto cond1 =
1944 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
1945 auto body1 =
1946 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
1947 auto input1 = builder.AddInstruction(
1948 HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
1949 auto tuple1 = builder.AddInstruction(
1950 HloInstruction::CreateTuple({input1, weights1, output1}));
1951 auto while1 = builder.AddInstruction(
1952 HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
1953
1954 module->AddEntryComputation(builder.Build());
1955 RunCopyInsertion(module.get());
1956 auto assignment = RunBufferAssignment(module.get());
1957
1958 // Verify 'input0' and read-only use while0{0} alias.
1959 EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(),
1960 assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie());
1961 // Verify 'weights0' and read-only use while0{1} alias.
1962 EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).ConsumeValueOrDie(),
1963 assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie());
1964 // Verify 'while0{2}' and read-only use while1{0} alias.
1965 EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
1966 assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
1967 // Verify 'weights1' and read-only use while1{1} alias.
1968 EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).ConsumeValueOrDie(),
1969 assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
1970 }
1971
1972 // Tests that two colocated buffer sets are not merged if an entry parameter
1973 // buffer belongs to either of the colocation sets (b/73267882).
1974 //
1975 // %param --> %while.0 --> %mul --> %while.1 --> %broadcast
1976 //
1977 // %while.0 body just forwards the init value, so the loop carried variable
1978 // remains the init value, whereas %while.1 changes the loop carried variable.
TEST_F(WhileBufferAssignmentTest,ColocatedBufferWithEntryParameter)1979 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) {
1980 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
1981
1982 const char* module_str = R"(
1983 HloModule test_module
1984
1985 %cond.v0 {
1986 %param = s32[] parameter(0)
1987 ROOT %constant = pred[] constant(true)
1988 }
1989
1990 %cond.v1 {
1991 %param.0 = s32[] parameter(0)
1992 ROOT %constant.0 = pred[] constant(true)
1993 }
1994
1995 %body.v0 {
1996 ROOT %param.1 = s32[] parameter(0)
1997 }
1998
1999 %body.v1 {
2000 %param.2 = s32[] parameter(0)
2001 ROOT add = s32[] add(%param.2, %param.2)
2002 }
2003
2004 ENTRY %test_module {
2005 %param.3 = s32[] parameter(0)
2006 %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0
2007 %mul = s32[] multiply(%while.0, %while.0)
2008 %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2009 ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2010 })";
2011
2012 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2013
2014 // Run CopyInsertion and check if the graph constructed above doesn't need
2015 // any copies inserted for BufferAssignment to run.
2016 int64 instruction_count = m->instruction_count();
2017 CopyInsertion copy_insertion;
2018 ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2019 ASSERT_EQ(instruction_count, m->instruction_count());
2020
2021 // Get the instructions in the module.
2022 const HloInstruction* bcast = m->entry_computation()->root_instruction();
2023 const HloInstruction* param =
2024 m->entry_computation()->parameter_instruction(0);
2025 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2026 const HloInstruction* while1 = bcast->operand(0);
2027 ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2028 const HloInstruction* while0 = while1->operand(0)->operand(0);
2029 ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2030
2031 // Run buffer assignment.
2032 auto assignment = RunBufferAssignment(m.get());
2033 TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
2034 assignment->GetUniqueSlice(param, {}));
2035 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2036 assignment->GetUniqueSlice(while0, {}));
2037 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2038 assignment->GetUniqueSlice(while1, {}));
2039
2040 // The parameter slice is part of the while0's colocation set (init value),
2041 // but not merged into the while1's colocation set.
2042 EXPECT_EQ(slice_param, slice_while0);
2043 EXPECT_NE(slice_param, slice_while1);
2044 }
2045
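// As in ColocatedBufferWithEntryParameter above, but the init value of %while.0
// is a constant; its colocation set must likewise not be merged with %while.1's.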
TEST_F(WhileBufferAssignmentTest,ColocatedBufferWithConstant)2046 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
2047 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2048
2049 const char* module_str = R"(
2050 HloModule test_module
2051
2052 %cond.v0 {
2053 %param = s32[] parameter(0)
2054 ROOT %constant = pred[] constant(true)
2055 }
2056
2057 %cond.v1 {
2058 %param.0 = s32[] parameter(0)
2059 ROOT %constant.0 = pred[] constant(true)
2060 }
2061
2062 %body.v0 {
2063 ROOT %param.1 = s32[] parameter(0)
2064 }
2065
2066 %body.v1 {
2067 %param.2 = s32[] parameter(0)
2068 ROOT add = s32[] add(%param.2, %param.2)
2069 }
2070
2071 ENTRY %test_module {
2072 %constant.42 = s32[] constant(42)
2073 %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
2074 %mul = s32[] multiply(%while.0, %while.0)
2075 %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2076 ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2077 })";
2078
2079 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2080
2081 // Run CopyInsertion and check if the graph constructed above doesn't need
2082 // any copies inserted for BufferAssignment to run.
2083 int64 instruction_count = m->instruction_count();
2084 CopyInsertion copy_insertion;
2085 ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2086 ASSERT_EQ(instruction_count, m->instruction_count());
2087
2088 // Get the instructions in the module.
2089 const HloInstruction* bcast = m->entry_computation()->root_instruction();
2090 const HloInstruction* constant =
2091 m->entry_computation()->GetInstructionWithName("constant.42");
2092 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2093 const HloInstruction* while1 = bcast->operand(0);
2094 ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2095 const HloInstruction* while0 = while1->operand(0)->operand(0);
2096 ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2097
2098 // Run buffer assignment.
2099 auto assignment = RunBufferAssignment(m.get());
2100 TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
2101 assignment->GetUniqueSlice(constant, {}));
2102 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2103 assignment->GetUniqueSlice(while0, {}));
2104 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2105 assignment->GetUniqueSlice(while1, {}));
2106
2107 // The constant slice is part of the while0's colocation set (init value), but
2108 // not merged into the while1's colocation set.
2109 EXPECT_EQ(slice_constant, slice_while0);
2110 EXPECT_NE(slice_constant, slice_while1);
2111 }
2112
2113 // Tests that the colocated buffers for while instructions are properly assigned
2114 // during buffer assignment such that the result tuple elements are not assigned
2115 // to the same buffer.
2116 //
2117 // %infeed --> %while.0 --> %while.1 --+
2118 // +-- %tuple
2119 // %zero --> %add --> %while.2 --+
2120 //
2121 // Execution Order:
2122 // %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
2123 //
2124 // The HLO computation used in this test requires specific ordering to expose
2125 // the bug (b/72496031). During buffer assignment, the visitation order of
2126 // colocated buffers is %while.2 -> %while.0 -> %while.1, and buffer
2127 // assignment used to coalesce the colocated buffers for all three while
2128 // instructions, assigning the same buffer to the two result tuple elements.
TEST_F(WhileBufferAssignmentTest,ColocatedBuffers)2129 TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
2130 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2131
2132 // Builds a condition computation: x -> x < 4
2133 auto build_cond = [&]() {
2134 auto builder = HloComputation::Builder("cond");
2135 auto const4 = builder.AddInstruction(
2136 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
2137 auto param =
2138 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2139 builder.AddInstruction(
2140 HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
2141 const4, ComparisonDirection::kLt));
2142 return builder.Build();
2143 };
2144
2145 // Builds a body computation: x -> x + 9
2146 auto build_body = [&]() {
2147 auto builder = HloComputation::Builder("body");
2148 auto const9 = builder.AddInstruction(
2149 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
2150 auto param =
2151 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2152 builder.AddInstruction(
2153 HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
2154 return builder.Build();
2155 };
2156
2157 // Build the entry computation as described in the comment above.
2158 auto module = CreateNewVerifiedModule();
2159 auto builder = HloComputation::Builder("entry");
2160
2161 auto token = builder.AddInstruction(HloInstruction::CreateToken());
2162 auto infeed =
2163 builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
2164 auto infeed_data = builder.AddInstruction(
2165 HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
2166 auto cond0 = module->AddEmbeddedComputation(build_cond());
2167 auto body0 = module->AddEmbeddedComputation(build_body());
2168 auto while0 = builder.AddInstruction(
2169 HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));
2170
2171 auto cond1 = module->AddEmbeddedComputation(build_cond());
2172 auto body1 = module->AddEmbeddedComputation(build_body());
2173 auto while1 = builder.AddInstruction(
2174 HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
2175
2176 auto zero = builder.AddInstruction(
2177 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
2178 auto add = builder.AddInstruction(
2179 HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
2180 auto cond2 = module->AddEmbeddedComputation(build_cond());
2181 auto body2 = module->AddEmbeddedComputation(build_body());
2182 auto while2 = builder.AddInstruction(
2183 HloInstruction::CreateWhile(r0s32, cond2, body2, add));
2184
2185 auto tuple =
2186 builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
2187 module->AddEntryComputation(builder.Build());
2188
2189 // Run CopyInsertion and check if the graph constructed above doesn't need
2190 // any copies inserted for BufferAssignment to run.
2191 int64 instruction_count = module->instruction_count();
2192 CopyInsertion copy_insertion;
2193 ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
2194 ASSERT_EQ(instruction_count, module->instruction_count());
2195
2196 // Create a sequential order among all the instructions in the entry
2197 // computation, since the issue this test stresses depends on the order the
2198 // nodes are traversed during BufferAssignment.
2199 TF_ASSERT_OK_AND_ASSIGN(
2200 HloSchedule schedule,
2201 ScheduleModule(module.get(), [](const BufferValue& buffer) {
2202 return ShapeUtil::ByteSizeOf(buffer.shape(),
2203 /*pointer_size=*/sizeof(void*));
2204 }));
2205 schedule.set_sequence(
2206 module->entry_computation(),
2207 {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
2208 TF_ASSERT_OK(schedule.Verify());
2209
2210 TF_ASSERT_OK_AND_ASSIGN(
2211 auto assignment,
2212 BufferAssigner::Run(
2213 module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
2214 backend().compiler()->BufferSizeBytesFunction(),
2215 [](LogicalBuffer::Color) { return 1; },
2216 /*allow_input_output_aliasing=*/false,
2217 /*allocate_buffers_for_constants=*/true));
2218
2219 // The result tuple elements must be assigned with different buffers.
2220 TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
2221 TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
2222 EXPECT_NE(slice0, slice1);
2223
2224 // while0 and while1 result buffers must be equal to slice1.
2225 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2226 assignment->GetUniqueSlice(while0, {}));
2227 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2228 assignment->GetUniqueSlice(while1, {}));
2229 EXPECT_EQ(slice1, slice_while0);
2230 EXPECT_EQ(slice1, slice_while1);
2231
2232 // while2 result buffer must be equal to slice0.
2233 TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
2234 assignment->GetUniqueSlice(while2, {}));
2235 EXPECT_EQ(slice0, slice_while2);
2236 }
2237
TEST_F(WhileBufferAssignmentTest,OneForwardBackwardWhileLoopSet)2238 TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
2239 auto module = CreateNewVerifiedModule();
2240 auto builder = HloComputation::Builder("entry");
2241
2242 auto input0 = builder.AddInstruction(
2243 HloInstruction::CreateParameter(0, data_shape_, "input0"));
2244 auto weights0 = builder.AddInstruction(
2245 HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2246
2247 auto zero = builder.AddInstruction(
2248 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2249 auto output0 = builder.AddInstruction(
2250 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2251
2252 auto cond0 =
2253 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2254 auto body0 =
2255 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2256
2257 auto tuple0 = builder.AddInstruction(
2258 HloInstruction::CreateTuple({input0, weights0, output0}));
2259 auto while0 = builder.AddInstruction(
2260 HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2261
2262 auto cond1 =
2263 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2264 auto body1 =
2265 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2266
2267 auto while1 = builder.AddInstruction(
2268 HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
2269
2270 module->AddEntryComputation(builder.Build());
2271 RunCopyInsertion(module.get());
2272 auto assignment = RunBufferAssignment(module.get());
2273
2274   // The while0 and while1 result buffers should alias element-for-element.
2275 EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(),
2276 assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
2277 EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(),
2278 assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
2279 EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
2280 assignment->GetUniqueSlice(while1, {2}).ConsumeValueOrDie());
2281 }
2282
TEST_F(BufferAssignmentTest,TwoCalls)2283 TEST_F(BufferAssignmentTest, TwoCalls) {
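  // Tests that two calls to the same sub-computation get distinct result
  // buffers once the call graph has been flattened.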
2284 auto module = CreateNewVerifiedModule();
2285 Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
2286 HloComputation* sub_computation;
2287 {
2288 auto builder = HloComputation::Builder(TestName() + "_sub_comp");
2289 auto param = builder.AddInstruction(
2290 HloInstruction::CreateParameter(0, r0f32, "param"));
2291 auto constant1 = builder.AddInstruction(
2292 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2293 auto add = builder.AddInstruction(
2294 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
2295 sub_computation = module->AddEmbeddedComputation(builder.Build(add));
2296 }
2297 auto builder = HloComputation::Builder(TestName());
2298 auto constant2 = builder.AddInstruction(
2299 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
2300 auto constant3 = builder.AddInstruction(
2301 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
2302 auto call1 = builder.AddInstruction(
2303 HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
2304 auto call2 = builder.AddInstruction(
2305 HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
2306 auto add1 = builder.AddInstruction(
2307 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
2308 auto add2 = builder.AddInstruction(
2309 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
2310 module->AddEntryComputation(builder.Build(add2));
2311
2312 {
2313 FlattenCallGraph flatten;
2314 TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
2315 EXPECT_TRUE(result);
2316 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
2317 }
2318
2319 RunCopyInsertion(module.get());
2320 auto assignment = RunBufferAssignment(module.get());
2321
2322 EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
2323 }
2324
TEST_F(BufferAssignmentTest,CallParamCoAllocation)2325 TEST_F(BufferAssignmentTest, CallParamCoAllocation) {
2326 const char* hlo_text = R"(
2327 HloModule CallParamCoAllocation
2328
2329 Callee {
2330 param0 = (f32[100],(f32[200],f32[300])) parameter(0)
2331 param1 = s32[20] parameter(1)
2332 ROOT constant = f32[] constant(1)
2333 }
2334
2335 ENTRY Main {
2336 entry_param0 = f32[100] parameter(0)
2337 entry_param1 = s32[20] parameter(1)
2338 custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target"
2339 call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call)
2340 ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee
2341 }
2342 )";
2343
2344 HloModuleConfig config;
2345 config.set_debug_options(GetDebugOptionsFromFlags());
2346 TF_ASSERT_OK_AND_ASSIGN(auto m,
2347 ParseAndReturnVerifiedModule(hlo_text, config));
2348
2349 auto buffers = RunBufferAssignment(m.get());
2350
2351 HloComputation* main = m->entry_computation();
2352 HloComputation* callee = m->GetComputationWithName("Callee");
2353 EXPECT_NE(callee, nullptr);
2354
2355 HloInstruction* param0 = callee->parameter_instruction(0);
2356 HloInstruction* param1 = callee->parameter_instruction(1);
2357
2358 HloInstruction* entry_param0 = main->parameter_instruction(0);
2359 HloInstruction* entry_param1 = main->parameter_instruction(1);
2360 HloInstruction* custom_call = main->GetInstructionWithName("custom_call");
2361
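  // Each callee parameter should be co-allocated with the corresponding operand
  // at the call site: the entry parameters and the custom-call result elements.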
2362 EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}),
2363 GetAllocation(*buffers, param0, {0}));
2364 EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}),
2365 GetAllocation(*buffers, param1, {}));
2366
2367 EXPECT_EQ(GetAllocation(*buffers, custom_call, {}),
2368 GetAllocation(*buffers, param0, {1}));
2369 EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}),
2370 GetAllocation(*buffers, param0, {1, 0}));
2371 EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}),
2372 GetAllocation(*buffers, param0, {1, 1}));
2373 }
2374
TEST_F(WhileBufferAssignmentTest,WhileLoopsInterferingResultRange)2375 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
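  // Regression test for b/38494731: with a hand-crafted schedule that makes the
  // two while loops' result live ranges interfere, their buffers must remain
  // distinct.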
2376 auto module = CreateNewVerifiedModule();
2377 auto builder = HloComputation::Builder(TestName());
2378
2379 auto zero = builder.AddInstruction(
2380 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2381 auto one = builder.AddInstruction(
2382 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2383
2384 auto input0 = builder.AddInstruction(
2385 HloInstruction::CreateParameter(0, data_shape_, "input0"));
2386 auto weights0 = builder.AddInstruction(
2387 HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2388 auto output0 = builder.AddInstruction(
2389 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2390
2391 auto input1 = builder.AddInstruction(
2392 HloInstruction::CreateParameter(2, data_shape_, "input1"));
2393 auto weights1 = builder.AddInstruction(
2394 HloInstruction::CreateParameter(3, data_shape_, "weights1"));
2395 auto output1 = builder.AddInstruction(
2396 HloInstruction::CreateBroadcast(data_shape_, one, {}));
2397
2398 auto cond =
2399 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2400 auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2401
2402 auto tuple0 = builder.AddInstruction(
2403 HloInstruction::CreateTuple({input0, weights0, output0}));
2404 auto tuple1 = builder.AddInstruction(
2405 HloInstruction::CreateTuple({input1, weights1, output1}));
2406
2407 auto while0 = builder.AddInstruction(
2408 HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
2409 auto while1 = builder.AddInstruction(
2410 HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
2411
2412 auto gte0 = builder.AddInstruction(
2413 HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
2414 auto gte1 = builder.AddInstruction(
2415 HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
2416 auto root_add = builder.AddInstruction(
2417 HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));
2418
2419 module->AddEntryComputation(builder.Build());
2420
2421 {
2422 FlattenCallGraph flatten;
2423 TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
2424 EXPECT_TRUE(result);
2425 }
2426
2427 RunCopyInsertion(module.get());
2428
2429 HloSchedule schedule =
2430 ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();
2431
2432   // To trigger b/38494731, we want a specific HLO schedule for the
2433   // entry computation, so we overwrite its schedule entry with a manually
2434 // crafted sequence.
2435 schedule.set_sequence(
2436 module->entry_computation(),
2437 {input1, weights1, one, output1, while1->mutable_operand(0), while1,
2438 input0, weights0, zero, output0, while0->mutable_operand(0), while0,
2439 gte0, gte1, root_add});
2440
2441 // If this ASSERT fails, we constructed a bogus sequence above and this test
2442 // itself is buggy.
2443 TF_ASSERT_OK(schedule.Verify());
2444
  auto assignment =
      BufferAssigner::Run(
          module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
          ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
          /*allow_input_output_aliasing=*/false,
          /*allocate_buffers_for_constants=*/true)
          .ConsumeValueOrDie();

  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}

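// Tests that a while result that is live out of the entry computation is not
// assigned to an entry-parameter allocation.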
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder("entry");

  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));
  auto output1 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));

  // Get the output of 'while0' and feed it as input to 'while1'.
  auto while0_out = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));

  auto cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({while0_out, weights0, output1}));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));

  // Get the output of 'while1' so that it is live out of the computation.
  auto while1_out = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));

  module->AddEntryComputation(builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());
  // Get the BufferAllocation for the root instruction.
  auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out)
                         .ConsumeValueOrDie()
                         .allocation();
  // Test that the root instruction's allocation is live out.
  EXPECT_TRUE(root_alloc->maybe_live_out());
  // Test that the root instruction's allocation is not an entry parameter.
  EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
}

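// Tests that the two dynamic-update-slice ops in the while body share the same
// allocation slice, i.e. the second update is performed in place.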
TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) {
  const char* const hlo_string = R"(
HloModule test

while_body {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}

while_condition {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  get-tuple-element = s32[] get-tuple-element(state), index=0
  get-tuple-element.1 = s32[] constant(3)
  ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
}

ENTRY entry_computation {
  constant.7 = s32[] constant(0)
  copy.1 = s32[] copy(constant.7)
  constant.6 = f32[] constant(0)
  broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={}
  tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6)
  while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body
  ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0
}

)";
  auto module_or_status =
      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
  auto module = module_or_status.ConsumeValueOrDie();

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());
  // Get the top-level allocation slices for the two dynamic-update-slice ops.
  auto dus9 = FindInstruction(module.get(), "dynamic-update-slice.9");
  auto dus9_alloc_slice =
      assignment->GetUniqueTopLevelSlice(dus9).ConsumeValueOrDie();
  auto dus5 = FindInstruction(module.get(), "dynamic-update-slice.5");
  auto dus5_alloc_slice =
      assignment->GetUniqueTopLevelSlice(dus5).ConsumeValueOrDie();
  // Test that the two dynamic-update-slice ops share the same allocation
  // slice.
  EXPECT_EQ(dus9_alloc_slice.allocation(), dus5_alloc_slice.allocation());
  EXPECT_EQ(dus9_alloc_slice, dus5_alloc_slice);
}
}  // namespace
}  // namespace xla