1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
17
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/service/buffer_value.h"
29 #include "tensorflow/compiler/xla/service/call_graph.h"
30 #include "tensorflow/compiler/xla/service/copy_insertion.h"
31 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
32 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
38 #include "tensorflow/compiler/xla/service/hlo_parser.h"
39 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/test.h"
42 #include "tensorflow/compiler/xla/test_helpers.h"
43 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/xla_data.pb.h"
46 #include "tensorflow/core/lib/core/status_test_util.h"
47 #include "tensorflow/core/platform/macros.h"
48
49 namespace xla {
50 namespace {
51
52 using ::testing::UnorderedElementsAre;
53
54 // DFS visitor that collects the instructions referenced by a computation
55 // without descending into nested computations, i.e., only from the operands.
56 class InstructionListVisitor : public DfsHloVisitorWithDefault {
57 public:
InstructionListVisitor(const HloInstruction * root)58 explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}
59
DefaultAction(HloInstruction * hlo)60 Status DefaultAction(HloInstruction* hlo) override {
61 // For each instruction, just push it on the list after walking the
62 // operands.
63 instructions_.push_back(hlo);
64 VLOG(0) << "List instruction " << hlo->ToString();
65 return Status::OK();
66 }
67
GetInstructions()68 std::vector<const HloInstruction*> GetInstructions() { return instructions_; }
69
70 private:
71 // The instruction root of the computation.
72 const HloInstruction* root_;
73
74 // The full set of instructions found (may be duplicates, e.g., kParameter).
75 std::vector<const HloInstruction*> instructions_;
76
77 TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor);
78 };
79
GetInstructions(HloInstruction * root)80 const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
81 InstructionListVisitor main_list(root);
82 TF_CHECK_OK(root->Accept(&main_list));
83 return main_list.GetInstructions();
84 }
85
86 class BufferAssignmentTest : public HloTestBase {
87 protected:
~BufferAssignmentTest()88 ~BufferAssignmentTest() override {}
89
RunBufferAssignment(HloModule * module,int64 alignment=1)90 std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
91 int64 alignment = 1) {
92 return BufferAssigner::Run(
93 module, absl::make_unique<DependencyHloOrdering>(module),
94 backend().compiler()->BufferSizeBytesFunction(),
95 [alignment](LogicalBuffer::Color) { return alignment; },
96 /*allocate_buffers_for_constants=*/true)
97 .ConsumeValueOrDie();
98 }
99
RunBufferAssignmentNoBuffersForConstants(HloModule * module,int64 alignment=1)100 std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
101 HloModule* module, int64 alignment = 1) {
102 return BufferAssigner::Run(
103 module, absl::make_unique<DependencyHloOrdering>(module),
104 backend().compiler()->BufferSizeBytesFunction(),
105 [alignment](LogicalBuffer::Color) { return alignment; },
106 /*allocate_buffers_for_constants=*/false)
107 .ConsumeValueOrDie();
108 }
109
RunBufferAssignmentNoBuffersReuseForAdd(HloModule * module,int64 alignment=1)110 std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersReuseForAdd(
111 HloModule* module, int64 alignment = 1) {
112 absl::flat_hash_set<HloOpcode> must_not_live_out = {HloOpcode::kAdd};
113
114 return BufferAssigner::Run(
115 module, absl::make_unique<DependencyHloOrdering>(module),
116 backend().compiler()->BufferSizeBytesFunction(),
117 [alignment](LogicalBuffer::Color) { return alignment; },
118 /*allocate_buffers_for_constants=*/false,
119 /*colorer=*/BufferAssigner::DefaultColorer(),
120 /*must_not_live_out=*/must_not_live_out)
121 .ConsumeValueOrDie();
122 }
123
RunColoredBufferAssignment(HloModule * module,BufferAssigner::Colorer colorer,int64 alignment=1)124 std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
125 HloModule* module, BufferAssigner::Colorer colorer, int64 alignment = 1) {
126 return BufferAssigner::Run(
127 module, absl::make_unique<DependencyHloOrdering>(module),
128 backend().compiler()->BufferSizeBytesFunction(),
129 [alignment](LogicalBuffer::Color) { return alignment; },
130 /*allocate_buffers_for_constants=*/true, std::move(colorer))
131 .ConsumeValueOrDie();
132 }
133
RunBufferAssignmentWithInstructionSequence(HloModule * module,absl::Span<HloInstruction * const> instruction_sequence,int64 alignment=1)134 std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
135 HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
136 int64 alignment = 1) {
137 HloSchedule schedule(module);
138 schedule.set_sequence(module->entry_computation(), instruction_sequence);
139 return BufferAssigner::Run(
140 module, absl::make_unique<SequentialHloOrdering>(schedule),
141 backend().compiler()->BufferSizeBytesFunction(),
142 [alignment](LogicalBuffer::Color) { return alignment; },
143 /*allocate_buffers_for_constants=*/true)
144 .ConsumeValueOrDie();
145 }
146
RunBufferAssignmentWithPresetAssignments(HloModule * module,std::unique_ptr<PresetAssignments> preset_assignments,int64 alignment=1)147 std::unique_ptr<BufferAssignment> RunBufferAssignmentWithPresetAssignments(
148 HloModule* module, std::unique_ptr<PresetAssignments> preset_assignments,
149 int64 alignment = 1) {
150 return BufferAssigner::Run(
151 module, absl::make_unique<DependencyHloOrdering>(module),
152 backend().compiler()->BufferSizeBytesFunction(),
153 [alignment](LogicalBuffer::Color) { return alignment; },
154 /*allocate_buffers_for_constants=*/true,
155 BufferAssigner::DefaultColorer(),
156 /*must_not_live_out=*/{},
157 /*can_share_buffer=*/nullptr, std::move(preset_assignments))
158 .ConsumeValueOrDie();
159 }
160
161 // Builds an x+1.0 computation to use in a Map.
BuildMapComputationPlus1(const string & name)162 std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
163 auto builder = HloComputation::Builder(name);
164 auto param =
165 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
166 auto value = builder.AddInstruction(
167 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
168 builder.AddInstruction(
169 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
170 return builder.Build();
171 }
172
BuildReduceComputation(const string & name)173 std::unique_ptr<HloComputation> BuildReduceComputation(const string& name) {
174 auto builder = HloComputation::Builder(name);
175 auto param =
176 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
177 auto param2 =
178 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
179 builder.AddInstruction(
180 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
181 return builder.Build();
182 }
183
184 // Builds a simple compare-to-limit (x < 4) computation for a While.
185 //
186 // condition:
187 // const4[s32] -----------------------------------\
188 // \
189 // param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
190 //
BuildWhileConditionComputation(const string & name)191 std::unique_ptr<HloComputation> BuildWhileConditionComputation(
192 const string& name) {
193 auto builder = HloComputation::Builder(name);
194 auto const4 = builder.AddInstruction(
195 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
196 auto param = builder.AddInstruction(
197 HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
198 auto index = builder.AddInstruction(
199 HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
200 builder.AddInstruction(
201 HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
202 const4, ComparisonDirection::kLt));
203 return builder.Build();
204 }
205
206 // Builds a simple body computation for a While.
207 //
208 // body:
209 // constv[f32[4]] --------------------------------------\
210 // \
211 // /--- get-tuple-elementv[1] --- addv ---\
212 // param[(s32,f32[4])] ---| tuple
213 // \--- get-tuple-elementc[0] --- addc ---/
214 // /
215 // const1[s32] -----------------------------------------/
216 //
BuildWhileBodyComputation(const string & name)217 std::unique_ptr<HloComputation> BuildWhileBodyComputation(
218 const string& name) {
219 auto builder = HloComputation::Builder(name);
220 auto const1 = builder.AddInstruction(
221 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
222 auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
223 LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
224 auto param = builder.AddInstruction(
225 HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
226 auto indexc = builder.AddInstruction(
227 HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
228 auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
229 indexc->shape(), HloOpcode::kAdd, indexc, const1));
230 auto indexv = builder.AddInstruction(
231 HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
232 auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
233 constv->shape(), HloOpcode::kAdd, indexv, constv));
234 builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
235 return builder.Build();
236 }
237
BuildR0F32UnaryOpComputation(HloOpcode opcode,const string & name)238 std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
239 HloOpcode opcode, const string& name) {
240 auto builder = HloComputation::Builder(name);
241 auto param =
242 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
243 builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
244 return builder.Build();
245 }
246
247 // Verifies that the given instruction hlo has a valid input buffer assigned,
248 // i.e., the parameter number matches the op's.
GetAssignedInputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)249 const BufferAllocation& GetAssignedInputAllocation(
250 const BufferAssignment& buffers, HloInstruction* hlo) {
251 LOG(INFO) << "Checking input: " << hlo->ToString();
252 const BufferAllocation& buffer =
253 *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation();
254 EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
255 return buffer;
256 }
257
258 // Verifies that the given instruction hlo has a valid output buffer
259 // assigned, and returns it.
GetAssignedOutputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)260 const BufferAllocation& GetAssignedOutputAllocation(
261 const BufferAssignment& buffers, HloInstruction* hlo) {
262 LOG(INFO) << "Checking output: " << hlo->ToString();
263 const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
264 return buffer;
265 }
266
267 // Returns the allocation for the given instruction.
GetAllocation(const BufferAssignment & buffers,const HloInstruction * hlo,const ShapeIndex & index)268 const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
269 const HloInstruction* hlo,
270 const ShapeIndex& index) {
271 return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation();
272 }
GetTopLevelAllocation(const BufferAssignment & buffers,const HloInstruction * hlo)273 const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
274 const HloInstruction* hlo) {
275 return *buffers.GetUniqueTopLevelSlice(hlo)
276 .ConsumeValueOrDie()
277 .allocation();
278 }
279
280 // Verifies that all instructions in the given instruction list except
281 // kConstant have assigned buffers, and returns their total size. If min_index
282 // and max_index are not nullptr, the minimum and maximum buffer indices in
283 // the assignment are written into them.
ValidateBuffers(const std::vector<const HloInstruction * > & instructions,const BufferAssignment & buffers)284 int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions,
285 const BufferAssignment& buffers) {
286 // Verifies all instructions have buffers, and gets the index ranges.
287 for (const HloInstruction* hlo : instructions) {
288 if (!buffers.HasTopLevelAllocation(hlo)) {
289 // If `hlo` has no assigned buffer, it is either a constant or a nested
290 // parameter.
291 EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
292 HloOpcode::kParameter == hlo->opcode());
293 continue;
294 }
295 }
296
297 // Gets the total size of all buffers assigned.
298 int64 total_size = 0;
299 for (auto& allocation : buffers.Allocations()) {
300 total_size += allocation.size();
301 }
302 return total_size;
303 }
304
305 // Shapes for use in the examples.
306 Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
307 Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
308 Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
309 Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
310 Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
311 Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
312 Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
313 Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
314 };
315
316 // Returns true if the buffers assigned to instructions in "a" are distinct
317 // from the buffers assigned to those in "b" (ie, intersection is empty).
BuffersDistinct(const std::vector<const HloInstruction * > & a,const std::vector<const HloInstruction * > & b,const BufferAssignment & assignment)318 static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
319 const std::vector<const HloInstruction*>& b,
320 const BufferAssignment& assignment) {
321 absl::flat_hash_set<BufferAllocation::Slice> a_slices;
322 for (const HloInstruction* instruction : a) {
323 if (assignment.HasTopLevelAllocation(instruction)) {
324 a_slices.insert(
325 assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie());
326 }
327 }
328
329 for (const HloInstruction* instruction : b) {
330 if (assignment.HasTopLevelAllocation(instruction)) {
331 if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction)
332 .ConsumeValueOrDie())) {
333 return false;
334 }
335 }
336 }
337 return true;
338 }
339
340 // Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest,ScalarConstant)341 TEST_F(BufferAssignmentTest, ScalarConstant) {
342 auto builder = HloComputation::Builder(TestName());
343 auto const0 = builder.AddInstruction(
344 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
345 auto module = CreateNewVerifiedModule();
346 module->AddEntryComputation(builder.Build());
347
348 {
349 auto buffers = RunBufferAssignment(module.get());
350 EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
351 }
352
353 {
354 auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
355 EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
356 }
357 }
358
TEST_F(BufferAssignmentTest, BufferForConst) {
  // Addition of two vector constants. The constants have buffers only when
  // buffers for constants are requested; their consumer (the add) has a buffer
  // in both configurations.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    // With buffers for constants.
    auto assignment = RunBufferAssignment(module.get());
    EXPECT_TRUE(assignment->HasTopLevelAllocation(lhs));
    EXPECT_TRUE(assignment->HasTopLevelAllocation(rhs));
    GetAssignedOutputAllocation(*assignment, sum);
  }
  {
    // Without buffers for constants.
    auto assignment = RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(assignment->HasTopLevelAllocation(lhs));
    EXPECT_FALSE(assignment->HasTopLevelAllocation(rhs));
    GetAssignedOutputAllocation(*assignment, sum);
  }
}
385
TEST_F(BufferAssignmentTest, HasAllocationAt) {
  // Builds a tuple whose elements are a computed value, a parameter and a
  // constant, then checks HasAllocationAt() on the tuple against
  // HasTopLevelAllocation() on the corresponding element instruction.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate, param0, constant}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());
  // The tuple itself is compared at the empty index; element i of the tuple is
  // compared against the instruction that produced it.
  EXPECT_EQ(assignment->HasTopLevelAllocation(tuple),
            assignment->HasAllocationAt(tuple, /*index=*/{}));
  EXPECT_EQ(assignment->HasTopLevelAllocation(negate),
            assignment->HasAllocationAt(tuple, /*index=*/{0}));
  EXPECT_EQ(assignment->HasTopLevelAllocation(param0),
            assignment->HasAllocationAt(tuple, /*index=*/{1}));
  EXPECT_EQ(assignment->HasTopLevelAllocation(constant),
            assignment->HasAllocationAt(tuple, /*index=*/{2}));
}
413
TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  // The computation copies a vector constant to its output.
  HloComputation::Builder builder(TestName());
  HloInstruction* source =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(source->shape(), HloOpcode::kCopy, source));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());
  // The copy must have been given an output buffer.
  GetAssignedOutputAllocation(*assignment, copy);
}
428
TEST_F(BufferAssignmentTest, Basic) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                             /        /
  // param1[100] --------------/--------/
  //
  // Checks that the three parameters get distinct input allocations, that mul
  // does not share with param0, and that add reuses mul's allocation.
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Distinct input buffers were assigned for parameters. The allocations are
  // owned by `buffers`, so they are bound by reference rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
475
TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
  // If an input buffer and an output buffer alias, the input buffer can be
  // reused for other intermediate results.
  //
  // param0[100] ----- (neg1) -- (neg2)
  //    |                           |
  //    + -------- Aliased ---------+

  HloComputation::Builder builder(TestName());

  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p0"));
  HloInstruction* neg_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  HloInstruction* neg_2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Alias the output (at the empty index) with parameter 0 (at the empty
  // index).
  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {}));

  auto buffers = RunBufferAssignment(module.get());

  // The allocations are owned by `buffers`, so they are bound by reference
  // rather than copied.
  const BufferAllocation& param_buffer =
      GetAssignedInputAllocation(*buffers, param);
  const BufferAllocation& neg_1_buffer = GetAllocation(*buffers, neg_1, {});
  const BufferAllocation& neg_2_buffer = GetAllocation(*buffers, neg_2, {});

  // Everything uses one buffer.
  EXPECT_EQ(param_buffer.index(), neg_1_buffer.index());
  EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index());
}
508
TEST_F(BufferAssignmentTest, AddCannotReuse) {
  // Pass in a special rule to indicate that "add" cannot be live out.
  //
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                             /        /
  // param1[100] --------------/--------/
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // This helper passes {kAdd} as the must_not_live_out opcode set.
  auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get());

  // Distinct input buffers were assigned for parameters. The allocations are
  // owned by `buffers`, so they are bound by reference rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The sub node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& sub_buffer = GetTopLevelAllocation(*buffers, sub);
  EXPECT_NE(sub_buffer.index(), param0_buffer.index());

  // The add node must not share an allocation with sub (the last instruction
  // added to the computation), because kAdd was listed in must_not_live_out.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), sub_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
558
TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                             /        /
  // param1[100] --------------/--------/
  // The output of each op is colored with a different color, so we can not
  // share anything.
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Colorer that walks the dataflow values by id and gives each one the next
  // unused color, starting at 0.
  auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
    int color = 0;
    for (HloValue::Id id = 0;
         id < alias_analysis->dataflow_analysis().values().size(); id++) {
      auto& value = alias_analysis->dataflow_analysis().GetValue(id);
      value.set_color(BufferValue::Color(color++));
    }
    return Status::OK();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters. The allocations are
  // owned by `buffers`, so they are bound by reference rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can not reuse the mul node's buffer due to coloring.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);

  // Check if the HLO instructions have the correct colors in the layout.
  // NOTE(review): the expected values presumably equal the dataflow value id
  // of each instruction, since the colorer above assigns color == id; confirm
  // against how value ids are numbered.
  EXPECT_EQ(param0->shape().layout().memory_space(), 2);
  EXPECT_EQ(param1->shape().layout().memory_space(), 3);
  EXPECT_EQ(mul->shape().layout().memory_space(), 4);
  EXPECT_EQ(add->shape().layout().memory_space(), 5);
  EXPECT_EQ(sub->shape().layout().memory_space(), 6);
}
624
TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                             /        /
  // param1[100] --------------/--------/
  // The output of the mul and the add have the color 1, and the other buffers
  // have the color 0, which allows the mul and add to share buffers.
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Colorer: a value gets color 1 when the buffer containing it holds any
  // value defined by an add or a multiply; every value still uncolored after
  // that check gets color 0.
  auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
    for (HloValue::Id id = 0;
         id < alias_analysis->dataflow_analysis().values().size(); id++) {
      auto& value = alias_analysis->dataflow_analysis().GetValue(id);
      auto& buffer = alias_analysis->GetBufferContainingValue(value);
      for (const auto& alias : buffer.values()) {
        if (alias->instruction()->opcode() == HloOpcode::kAdd ||
            alias->instruction()->opcode() == HloOpcode::kMultiply) {
          value.set_color(LogicalBuffer::Color(1));
        }
      }
      if (!value.has_color()) {
        value.set_color(LogicalBuffer::Color(0));
      }
    }
    return Status::OK();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters. The allocations are
  // owned by `buffers`, so they are bound by reference rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);

  // Check if the HLO instructions have the correct colors in the layout.
  EXPECT_EQ(mul->shape().layout().memory_space(), 1);
  EXPECT_EQ(add->shape().layout().memory_space(), 1);
  EXPECT_EQ(sub->shape().layout().memory_space(), 0);
  EXPECT_EQ(param0->shape().layout().memory_space(), 0);
  EXPECT_EQ(param1->shape().layout().memory_space(), 0);
}
698
TEST_F(BufferAssignmentTest, PresetAssignments) {
  // Checks that chunks handed in through PresetAssignments are honored.
  //
  //   paramscalar ------- (mul) -- (add) -- (sub)
  //                       /        /        /
  //   param0[100] -------/        /        /
  //                              /        /
  //   param1[100] --------------/--------/
  //
  // The graph matches BasicPartiallyColored, except that the color is carried
  // by the shape's layout. The mul and add outputs use color 1 and are given
  // preset chunks; every other buffer uses color 0. Both preset values are
  // expected to end up in the same allocation.
  const Shape colored_vec100 =
      ShapeUtil::MakeShapeWithLayout(F32, {100}, {0}, {}, 0, 1);

  HloComputation::Builder builder(TestName());
  HloInstruction* scalar_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* scalar_bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, scalar_param, {}));
  HloInstruction* lhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* rhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* product = builder.AddInstruction(HloInstruction::CreateBinary(
      colored_vec100, HloOpcode::kMultiply, scalar_bcast, lhs_param));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      colored_vec100, HloOpcode::kAdd, product, rhs_param));
  HloInstruction* difference =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32vec100_, HloOpcode::kSubtract, sum, rhs_param));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Two non-overlapping 400-byte chunks inside a 950-byte color-1 space.
  auto presets = absl::make_unique<PresetAssignments>();
  presets->add_chunk({product, {}}, {/*offset=*/100, /*size=*/400});
  presets->add_chunk({sum, {}}, {/*offset=*/550, /*size=*/400});
  presets->assignment_information_for_space(/*memory_space=*/1)->size = 950;

  auto buffers = RunBufferAssignmentWithPresetAssignments(module.get(),
                                                          std::move(presets));

  // Each parameter owns its own input allocation, all in color 0.
  const BufferAllocation& scalar_alloc =
      GetAssignedInputAllocation(*buffers, scalar_param);
  const BufferAllocation& lhs_alloc =
      GetAssignedInputAllocation(*buffers, lhs_param);
  const BufferAllocation& rhs_alloc =
      GetAssignedInputAllocation(*buffers, rhs_param);
  EXPECT_NE(scalar_alloc.index(), lhs_alloc.index());
  EXPECT_NE(scalar_alloc.index(), rhs_alloc.index());
  EXPECT_EQ(scalar_alloc.color(), LogicalBuffer::Color(0));
  EXPECT_NE(lhs_alloc.index(), rhs_alloc.index());
  EXPECT_EQ(lhs_alloc.color(), LogicalBuffer::Color(0));

  // mul and add live in one preset allocation with color 1, distinct from the
  // parameter allocations.
  const BufferAllocation& product_alloc =
      GetTopLevelAllocation(*buffers, product);
  const BufferAllocation& sum_alloc = GetTopLevelAllocation(*buffers, sum);
  EXPECT_EQ(product_alloc, sum_alloc);
  EXPECT_NE(product_alloc.index(), lhs_alloc.index());
  EXPECT_EQ(product_alloc.color(), LogicalBuffer::Color(1));

  // Exactly the two preset values are assigned, each at its requested chunk.
  EXPECT_EQ(product_alloc.assigned_buffers().size(), 2);
  for (const auto& entry : product_alloc.assigned_buffers()) {
    const auto& chunk = entry.second;
    if (entry.first->instruction() != product) {
      EXPECT_EQ(entry.first->instruction(), sum);
      EXPECT_EQ(chunk.offset, 550);
      EXPECT_EQ(chunk.size, 400);
    } else {
      EXPECT_EQ(chunk.offset, 100);
      EXPECT_EQ(chunk.size, 400);
    }
  }

  // The root (sub) must have been given a valid output allocation.
  GetAssignedOutputAllocation(*buffers, difference);
}
772
TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
  // Exercises preset assignments in the presence of a while loop, where
  // HloValue and HloBuffer are not in 1-to-1 correspondence.
  auto module = CreateNewVerifiedModule();
  const Shape colored_vec10 =
      ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, 0, 1);
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({s32_, colored_vec10});

  // Loop condition: keep iterating while the counter is below 50.
  HloComputation* cond_computation = nullptr;
  {
    HloComputation::Builder cond_builder("WhileCond");
    auto* state = cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape, "cond_param"));
    auto* counter = cond_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(s32_, state, 0));
    auto* limit = cond_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
    cond_builder.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), counter, limit,
        ComparisonDirection::kLt));
    cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
  }

  // Loop body: add a constant vector to the data and bump the counter by one.
  HloComputation* body_computation = nullptr;
  {
    HloComputation::Builder body_builder("WhileBody");
    auto* state = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape, "body_param"));
    auto* counter = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(s32_, state, 0));
    auto* payload = body_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(colored_vec10, state, 1));
    auto* payload_step = body_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
            {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
    auto* next_payload = body_builder.AddInstruction(
        HloInstruction::CreateBinary(colored_vec10, HloOpcode::kAdd, payload,
                                     payload_step));
    auto* counter_step = body_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
    auto* next_counter = body_builder.AddInstruction(
        HloInstruction::CreateBinary(s32_, HloOpcode::kAdd, counter,
                                     counter_step));
    body_builder.AddInstruction(
        HloInstruction::CreateTuple({next_counter, next_payload}));
    body_computation = module->AddEmbeddedComputation(body_builder.Build());
  }

  // Entry computation: negate the input, run the loop, add the input back.
  HloComputation::Builder builder(TestName());
  auto* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, s32_, "param_iter"));
  auto* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
  auto* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(colored_vec10, HloOpcode::kNegate, data));
  auto* init_state =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
  auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, cond_computation, body_computation, init_state));
  auto* loop_result = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(colored_vec10, while_op, 1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec10_, HloOpcode::kAdd, loop_result, data));
  module->AddEntryComputation(builder.Build());

  // Provide a single preset chunk, keyed on negate; it is expected to apply to
  // the while data and all of its aliases.
  auto presets = absl::make_unique<PresetAssignments>();
  presets->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
  presets->assignment_information_for_space(/*memory_space=*/1)->size = 140;

  auto buffers = RunBufferAssignmentWithPresetAssignments(module.get(),
                                                          std::move(presets));

  // Every value in the allocation is aliased, so each one must report the
  // same offset, size and color.
  const BufferAllocation& data_allocation =
      GetTopLevelAllocation(*buffers, negate);
  EXPECT_EQ(data_allocation.assigned_buffers().size(), 5);
  for (const auto& entry : data_allocation.assigned_buffers()) {
    EXPECT_EQ(entry.second.offset, 100);
    EXPECT_EQ(entry.second.size, 40);
    EXPECT_EQ(entry.first->color(), LogicalBuffer::Color(1));
  }
}
854
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
  // Variant of the Basic test in which (sub) also consumes (mul)'s result.
  // Since (mul) has a second user, its buffer cannot be reused for (add)'s
  // output.
  //
  //   paramscalar ------+    +------------+
  //                     |    |            |
  //   param0[100] ----- (mul) -- (add) -- (sub)
  //                              /
  //   param1[100] --------------/
  //
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* scalar_bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, scalar_param, {}));
  HloInstruction* lhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* rhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* product = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, scalar_bcast, lhs_param));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kAdd, product, rhs_param));
  HloInstruction* difference =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32vec100_, HloOpcode::kSubtract, sum, product));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // The three parameters must be backed by three different input allocations.
  const BufferAllocation& scalar_alloc =
      GetAssignedInputAllocation(*buffers, scalar_param);
  const BufferAllocation& lhs_alloc =
      GetAssignedInputAllocation(*buffers, lhs_param);
  const BufferAllocation& rhs_alloc =
      GetAssignedInputAllocation(*buffers, rhs_param);
  EXPECT_NE(scalar_alloc.index(), lhs_alloc.index());
  EXPECT_NE(scalar_alloc.index(), rhs_alloc.index());
  EXPECT_NE(lhs_alloc.index(), rhs_alloc.index());

  // (mul) has an allocation, and (add) must not be sharing it.
  const BufferAllocation& product_alloc =
      GetTopLevelAllocation(*buffers, product);
  const BufferAllocation& sum_alloc = GetTopLevelAllocation(*buffers, sum);
  EXPECT_NE(sum_alloc.index(), product_alloc.index());

  // Validate the buffers reachable from the root and log their total size.
  const std::vector<const HloInstruction*> entry_instrs =
      GetInstructions(difference);
  const int64 entry_size = ValidateBuffers(entry_instrs, *buffers);
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << entry_instrs.size() << " instructions; "
            << "total buffer size " << entry_size;
}
909
TEST_F(BufferAssignmentTest, TrivialMap) {
  // A single x+1 map is the only operation in the entry computation:
  //
  //   param0[100x10] ---> (map x+1)
  //
  auto module = CreateNewVerifiedModule();

  // The mapped function, registered as an embedded computation.
  HloComputation* map_computation =
      module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
  HloInstruction* mapped_root = map_computation->root_instruction();

  // The entry computation: one parameter feeding the map.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10_, {input}, map_computation));
  module->AddEntryComputation(builder.Build());

  // Sanity-check the instruction counts of both computations.
  const std::vector<const HloInstruction*> entry_instrs = GetInstructions(map);
  EXPECT_EQ(2, entry_instrs.size()) << "Invalid main kernel size";
  const std::vector<const HloInstruction*> mapped_instrs =
      GetInstructions(mapped_root);
  EXPECT_EQ(3, mapped_instrs.size()) << "Invalid nested add+1 size";

  // Run buffer assignment and collect the validated sizes.
  auto buffers = RunBufferAssignment(module.get());
  const int64 entry_size = ValidateBuffers(entry_instrs, *buffers);
  const int64 mapped_size = ValidateBuffers(mapped_instrs, *buffers);

  // The map's buffer is assigned before the embedded computation is
  // processed (true for both algorithms), so the two instruction sets must
  // not share any buffer.
  EXPECT_TRUE(BuffersDistinct(entry_instrs, mapped_instrs, *buffers))
      << "Reuse between main kernel and embedded mapping.";

  // The parameter has an input allocation and the map has a different output
  // allocation.
  const BufferAllocation& input_alloc =
      GetAssignedInputAllocation(*buffers, input);
  const BufferAllocation& map_alloc =
      GetAssignedOutputAllocation(*buffers, map);
  EXPECT_NE(input_alloc.index(), map_alloc.index());

  // The mapped computation ends in an add (f32 param plus constant) whose
  // allocation differs from the map's output.
  EXPECT_EQ(HloOpcode::kAdd, mapped_root->opcode());
  const BufferAllocation& mapped_add_alloc =
      GetTopLevelAllocation(*buffers, mapped_root);
  EXPECT_NE(mapped_add_alloc.index(), map_alloc.index());

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << entry_instrs.size() + mapped_instrs.size()
            << " instructions; "
            << "total buffer size " << entry_size + mapped_size;
}
964
TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  // A reduce must never write its output into its operand's buffer. Reuse is
  // unsafe in general: the reduce reshapes, and an out-of-order reduction
  // could overwrite an element before it has been read.
  //
  //   param0[100x10] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
  auto module = CreateNewVerifiedModule();
  HloComputation* reduce_computation =
      module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  HloInstruction* first_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, input));
  HloInstruction* second_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, first_exp));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      /*shape=*/f32vec10_,
      /*operand=*/second_exp,
      /*init_value=*/zero,
      /*dimensions_to_reduce=*/{0}, reduce_computation));
  HloInstruction* final_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));

  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  const std::vector<const HloInstruction*> all_instrs =
      GetInstructions(final_exp);
  ValidateBuffers(all_instrs, *buffers);

  const BufferAllocation& first_exp_alloc =
      GetTopLevelAllocation(*buffers, first_exp);
  const BufferAllocation& second_exp_alloc =
      GetTopLevelAllocation(*buffers, second_exp);
  const BufferAllocation& reduce_alloc =
      GetTopLevelAllocation(*buffers, reduce);

  // Sanity check: exp2 trivially reuses exp1's buffer.
  EXPECT_EQ(first_exp_alloc.index(), second_exp_alloc.index());

  // exp2 is the reduce's only operand, yet its buffer must not be reused for
  // the reduce output.
  EXPECT_NE(second_exp_alloc.index(), reduce_alloc.index());
}
1011
TEST_F(BufferAssignmentTest, ExampleWhile) {
  // Reproduces the While loop example from the ir_semantics document.
  //
  // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
  // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
  //
  //   const3[s32] ------+
  //                     |
  //   const4[f32[4]] -- tuple --- while[condition, body]
  //
  // First register the nested condition and body computations.
  auto module = CreateNewVerifiedModule();
  HloComputation* condition_computation =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));

  // Then build the entry computation that drives the loop.
  HloComputation::Builder builder(TestName());
  HloInstruction* const3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
  HloInstruction* const4 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* init_state =
      builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v4_, condition_computation, body_computation, init_state));
  module->AddEntryComputation(builder.Build());

  // Verify the instruction counts of the three computations.
  const std::vector<const HloInstruction*> entry_instrs =
      GetInstructions(while_op);
  EXPECT_EQ(4, entry_instrs.size()) << "Invalid while kernel size";
  const std::vector<const HloInstruction*> cond_instrs =
      GetInstructions(condition_computation->root_instruction());
  EXPECT_EQ(4, cond_instrs.size()) << "Invalid nested condition size";
  const std::vector<const HloInstruction*> body_instrs =
      GetInstructions(body_computation->root_instruction());
  EXPECT_EQ(8, body_instrs.size()) << "Invalid nested body size";

  // Assign buffers and collect the validated sizes.
  auto buffers = RunBufferAssignment(module.get());
  const int64 entry_size = ValidateBuffers(entry_instrs, *buffers);
  const int64 cond_size = ValidateBuffers(cond_instrs, *buffers);
  const int64 body_size = ValidateBuffers(body_instrs, *buffers);

  // A single allocation is expected to back while, while.cond.param,
  // while.body.param and while.body.result, so every pair of computations
  // should overlap.
  EXPECT_FALSE(BuffersDistinct(entry_instrs, cond_instrs, *buffers))
      << "Should be reuse between main kernel and embedded condition.";
  EXPECT_FALSE(BuffersDistinct(body_instrs, cond_instrs, *buffers))
      << "Should be reuse between embedded condition and body.";
  EXPECT_FALSE(BuffersDistinct(entry_instrs, body_instrs, *buffers))
      << "Should be reuse between main kernel and embedded body.";

  // The while body ends in a tuple of the s32 and f32[4] adds.
  HloInstruction* body_root = body_computation->root_instruction();
  EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());

  // For every subshape index, the while op and the body root must resolve to
  // the same allocation.
  const auto expect_shared_allocation =
      [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
                                            const ShapeIndex& index) {
        const auto& while_alloc = GetAllocation(*buffers, while_op, index);
        const auto& root_alloc = GetAllocation(*buffers, body_root, index);
        EXPECT_EQ(while_alloc.index(), root_alloc.index());
      };
  ShapeUtil::ForEachSubshape(while_op->shape(), expect_shared_allocation);

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for "
            << entry_instrs.size() + cond_instrs.size() + body_instrs.size()
            << " instructions; total buffer size "
            << entry_size + cond_size + body_size;
}
1086
TEST_F(BufferAssignmentTest, ExampleConditional) {
  // Builds a conditional whose true branch applies Ceil and whose false branch
  // applies Floor to a scalar, then checks that the conditional and both
  // branch computations reuse buffers with one another.
  auto module = CreateNewVerifiedModule();
  HloComputation* true_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
  HloComputation* false_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));

  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* true_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* false_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          r0f32_, pred, true_operand, true_computation, false_operand,
          false_computation));
  module->AddEntryComputation(builder.Build());

  HloInstruction* true_root = true_computation->root_instruction();
  HloInstruction* false_root = false_computation->root_instruction();

  // Check the instruction counts of the entry and branch computations.
  const std::vector<const HloInstruction*> conditional_instrs =
      GetInstructions(conditional);
  const std::vector<const HloInstruction*> true_instrs =
      GetInstructions(true_root);
  const std::vector<const HloInstruction*> false_instrs =
      GetInstructions(false_root);
  EXPECT_EQ(4, conditional_instrs.size());
  EXPECT_EQ(2, true_instrs.size());
  EXPECT_EQ(2, false_instrs.size());

  auto buffers = RunBufferAssignment(module.get());
  ValidateBuffers(conditional_instrs, *buffers);
  ValidateBuffers(true_instrs, *buffers);
  ValidateBuffers(false_instrs, *buffers);

  // Every pair of computations is expected to share at least one buffer.
  EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers))
      << "Should be reuse between conditional and true computation.";
  EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers))
      << "Should be reuse between conditional and false computation.";
  EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers))
      << "Should be reuse between true and false computations.";

  // The conditional's allocation has the same size as each branch root's.
  const BufferAllocation& conditional_alloc =
      GetTopLevelAllocation(*buffers, conditional);
  const BufferAllocation& true_alloc =
      GetTopLevelAllocation(*buffers, true_root);
  const BufferAllocation& false_alloc =
      GetTopLevelAllocation(*buffers, false_root);
  EXPECT_EQ(conditional_alloc.size(), true_alloc.size());
  EXPECT_EQ(conditional_alloc.size(), false_alloc.size());
}
1136
TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
  // A straight chain of same-shaped unary ops:
  //
  //   param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
  //
  // Every op after the first exp is expected to land in that exp's buffer.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  HloInstruction* first_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, input));
  HloInstruction* tanh_op = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, first_exp));
  HloInstruction* second_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh_op));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, second_exp));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // tanh, the second exp and neg all share the first exp's allocation.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(first_exp));
  const BufferAllocation& first_exp_alloc =
      GetTopLevelAllocation(*assignment, first_exp);
  EXPECT_EQ(first_exp_alloc, GetTopLevelAllocation(*assignment, tanh_op));
  EXPECT_EQ(first_exp_alloc, GetTopLevelAllocation(*assignment, second_exp));
  EXPECT_EQ(first_exp_alloc, GetTopLevelAllocation(*assignment, neg));
}
1162
TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
  // The buffer size first shrinks (slice) and then grows again (broadcast):
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The broadcast is expected to reuse the negate's buffer even though negate
  // is not one of its operands.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast and the negate resolve to the same allocation.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(broadcast_alloc, GetTopLevelAllocation(*assignment, negate));

  // The slice is kept in an allocation of its own.
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, slice));
}
1192
TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
  // Same chain as ReuseNonOperandBuffer, but negate is also an operand of the
  // output tuple. That keeps its value live until the end of the computation
  // and rules out reuse.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
  //                  \-----------------------------------/
  //
  // The negate must not share a buffer with the broadcast.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // All three instructions must sit in pairwise different allocations.
  const BufferAllocation& negate_alloc =
      GetTopLevelAllocation(*assignment, negate);
  const BufferAllocation& slice_alloc =
      GetTopLevelAllocation(*assignment, slice);
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_NE(broadcast_alloc, negate_alloc);
  EXPECT_NE(broadcast_alloc, slice_alloc);
  EXPECT_NE(negate_alloc, slice_alloc);
}
1225
TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
  // Same chain as ReuseNonOperandBuffer, but negate is wrapped in a tuple that
  // lives until the end of the computation. Through aliasing, that stretches
  // the live range of negate's buffer and rules out reuse.
  //
  //   param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple)
  //                               \-----------------------------------/
  //
  // The negate must not share a buffer with the broadcast.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  HloInstruction* unpacked = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec100_, inner_tuple, 0));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, unpacked, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // All three instructions must sit in pairwise different allocations.
  const BufferAllocation& negate_alloc =
      GetTopLevelAllocation(*assignment, negate);
  const BufferAllocation& slice_alloc =
      GetTopLevelAllocation(*assignment, slice);
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_NE(broadcast_alloc, negate_alloc);
  EXPECT_NE(broadcast_alloc, slice_alloc);
  EXPECT_NE(negate_alloc, slice_alloc);
}
1262
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
  // Close to ReuseNonOperandBuffer, except that the broadcast output is
  // smaller than the negate output. Output buffers of a computation should be
  // sized exactly for their value, so the broadcast must not take over
  // negate's larger buffer.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // Neither negate nor slice may share a buffer with broadcast.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // 100-element negate output.
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  // 10-element slice output.
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // 40-element (10x4) broadcast output.
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {10, 4});
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast's output allocation is shared with neither producer.
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, slice));
}
1295
TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
  // Counterpart of DoNotReuseOversizedOutputBuffer in which the broadcast
  // output has exactly the same size as the negate output (rather than being
  // smaller). Because the output buffer is then sized exactly to its value,
  // the broadcast may reuse negate's buffer.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast; the slice should not.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 100 elements (10x10), the same count as negate.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}
1329
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
  // Closely related to ReuseNonOperandBuffer, with two differences: the
  // broadcast output is smaller than the negate output, and the broadcast is
  // returned from the computation as a tuple element. Output buffers of a
  // computation must be sized exactly for their value, and that also holds
  // for buffers aliased in the output (e.g. tuple elements), so the broadcast
  // must not reuse negate's buffer.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
  //
  // Expectation: broadcast shares a buffer with neither negate nor slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate produces 100 elements.
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  // Slice produces 10 elements.
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast produces 10 x 4 = 40 elements.
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {10, 4});
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, slice, {0}));
  // The tuple is added last, so it becomes the computation root.
  builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast output buffer must stay unshared.
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, slice));
}
1365
TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
  // Checks the thread-local, live-out and entry-parameter flags of the
  // allocations belonging to a map computation, a called computation and the
  // entry computation that uses both.
  auto module = CreateNewVerifiedModule();
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {42});
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

  // Scalar computation applied by the map: negate(map_param).
  HloComputation::Builder map_builder(TestName() + "_map");
  HloInstruction* map_param = map_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
  HloInstruction* map_root = map_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
  HloComputation* map_computation =
      module->AddEmbeddedComputation(map_builder.Build());

  // Vector computation invoked through kCall: exp(vec_param).
  HloComputation::Builder call_builder(TestName() + "_call");
  HloInstruction* call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
  HloInstruction* call_root = call_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
  HloComputation* call_computation =
      module->AddEmbeddedComputation(call_builder.Build());

  // Entry computation: map(call(param)).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(vec_shape, {param}, call_computation));
  HloInstruction* map = entry_builder.AddInstruction(
      HloInstruction::CreateMap(vec_shape, {call}, map_computation));
  module->AddEntryComputation(entry_builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  // Map computation: thread-local, not live-out, not an entry parameter.
  const BufferAllocation& map_param_alloc =
      GetTopLevelAllocation(*assignment, map_param);
  EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_param_alloc.maybe_live_out());
  EXPECT_TRUE(map_param_alloc.is_thread_local());

  const BufferAllocation& map_root_alloc =
      GetTopLevelAllocation(*assignment, map_root);
  EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_root_alloc.maybe_live_out());
  EXPECT_TRUE(map_root_alloc.is_thread_local());

  // Called computation: never thread-local. Its parameter's allocation is
  // expected to be flagged as an entry computation parameter.
  const BufferAllocation& call_param_alloc =
      GetTopLevelAllocation(*assignment, call_param);
  EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(call_param_alloc.maybe_live_out());
  EXPECT_FALSE(call_param_alloc.is_thread_local());

  const BufferAllocation& call_root_alloc =
      GetTopLevelAllocation(*assignment, call_root);
  EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(call_root_alloc.is_thread_local());

  // Entry computation: the parameter is an entry parameter, and the map
  // result (the root) is the one that may be live-out.
  const BufferAllocation& param_alloc =
      GetTopLevelAllocation(*assignment, param);
  EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(param_alloc.maybe_live_out());
  EXPECT_FALSE(param_alloc.is_thread_local());

  const BufferAllocation& map_alloc = GetTopLevelAllocation(*assignment, map);
  EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
  EXPECT_TRUE(map_alloc.maybe_live_out());
  EXPECT_FALSE(map_alloc.is_thread_local());
}
1437
TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
  // The entry computation consists solely of a three-element tuple parameter,
  // which is therefore also the computation's result.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(F32, {}),
                                 ShapeUtil::MakeShape(S32, {42})}),
      "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be four allocations: one for vector of pointers, and one for
  // each tuple element.
  EXPECT_EQ(4, assignment->Allocations().size());

  // Verify each buffer allocation is marked as an entry computation parameter
  // and is liveout.
  ShapeUtil::ForEachSubshape(
      tuple_param->shape(),
      [this, &assignment, tuple_param](const Shape& /*subshape*/,
                                       const ShapeIndex& index) {
        // Bind by reference: the allocation is only inspected, so copying it
        // for every subshape is unnecessary.
        const auto& allocation =
            GetAllocation(*assignment, tuple_param, index);
        EXPECT_TRUE(allocation.is_entry_computation_parameter());
        EXPECT_EQ(0, allocation.parameter_number());
        EXPECT_TRUE(allocation.maybe_live_out());
      });
}
1468
TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
  // The computation returns GetTupleElement(param, 1), where param is a tuple
  // whose element 1 is itself a two-element tuple.
  HloComputation::Builder builder(TestName());
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {42}), ShapeUtil::MakeShape(S32, {101})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), inner_tuple_shape});
  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));
  HloInstruction* tuple_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Only part of the input parameter is live-out: the top-level tuple buffer
  // is not.
  EXPECT_FALSE(
      GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
  // Index {1} is live-out because GetTupleElement({1}) forwards a pointer to
  // that allocation rather than defining a buffer of its own. The same holds
  // for the leaves below it.
  EXPECT_TRUE(
      GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
                  .maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
                  .maybe_live_out());

  // The GetTupleElement result itself is live-out.
  EXPECT_TRUE(
      GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());

  // The leaves reachable through the GetTupleElement alias the corresponding
  // leaves of the parameter, so their allocations must be identical.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
            GetAllocation(*assignment, tuple_element, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
            GetAllocation(*assignment, tuple_element, /*index=*/{1}));

  // GetTupleElement forwards a pointer to its underlying buffer, so its
  // top-level allocation is the same as that of parameter element {1}.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
            GetTopLevelAllocation(*assignment, tuple_element));
}
1516
1517 // TODO(b/32248867): Enable when buffer assignment gives allocations to
1518 // constants.
TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
  // The computation's only instruction, and hence its output, is a tuple
  // constant made of two int64 scalars. Three allocations are expected.
  HloComputation::Builder builder(TestName());
  Literal first_element = LiteralUtil::CreateR0<int64>(0);
  Literal second_element = LiteralUtil::CreateR0<int64>(1);
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({&first_element, &second_element})));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(3, assignment->Allocations().size());
}
1534
TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
  // The computation returns the result of an operand-less custom call whose
  // shape is a two-element tuple.
  HloComputation::Builder builder(TestName());
  const Shape result_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(S32, {101})});
  HloInstruction* custom_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          result_shape, /*operands=*/{},
          /*custom_call_target=*/"foo_function"));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Three allocations are expected, and the ones at the tuple's top level and
  // at both of its elements may be live-out.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
}
1554
TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
  // The entry computation returns the result of a call whose callee wraps
  // its parameter in a one-element tuple.
  auto module = CreateNewVerifiedModule();
  const Shape elem_shape = f32vec4_;
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});

  // Callee: tuple(sub_param).
  HloComputation::Builder callee_builder(TestName() + "_sub");
  HloInstruction* sub_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
  HloInstruction* sub_tuple =
      callee_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
  HloComputation* sub_computation =
      module->AddEmbeddedComputation(callee_builder.Build());

  // Entry: call(sub_computation, param).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "param"));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
  module->AddEntryComputation(entry_builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(2, assignment->Allocations().size());
  // The call's buffers are colocated with those of the callee.
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
            GetAllocation(*assignment, sub_tuple, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
            GetAllocation(*assignment, sub_param, /*index=*/{}));

  // The entry parameter does not alias the result tuple, but it does alias
  // the call operand, i.e. the callee's parameter.
  const BufferAllocation& param_alloc =
      GetTopLevelAllocation(*assignment, param);
  EXPECT_NE(param_alloc, GetTopLevelAllocation(*assignment, sub_tuple));
  EXPECT_EQ(param_alloc, GetTopLevelAllocation(*assignment, sub_param));
}
1591
TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
  // A chain of calls that all produce a tuple:
  //   A: call(B, tuple(param))
  //   B: call(C, param)
  //   C: call(D, param)
  //   D: param
  auto module = CreateNewVerifiedModule();
  const Shape elem_shape = f32vec4_;
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});

  // D simply returns its tuple parameter.
  HloComputation::Builder builder_d(TestName() + "_d");
  HloInstruction* d_param = builder_d.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
  std::unique_ptr<HloComputation> d_computation = builder_d.Build();

  // C calls D on its parameter.
  HloComputation::Builder builder_c(TestName() + "_c");
  HloInstruction* c_param = builder_c.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
  HloInstruction* c_call = builder_c.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
  std::unique_ptr<HloComputation> c_computation = builder_c.Build();

  // B calls C on its parameter.
  HloComputation::Builder builder_b(TestName() + "_b");
  HloInstruction* b_param = builder_b.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
  HloInstruction* b_call = builder_b.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
  std::unique_ptr<HloComputation> b_computation = builder_b.Build();

  // A (the entry) wraps its array parameter in a tuple and calls B on it.
  HloComputation::Builder builder_a(TestName());
  HloInstruction* a_param = builder_a.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "param"));
  HloInstruction* a_tuple =
      builder_a.AddInstruction(HloInstruction::CreateTuple({a_param}));
  HloInstruction* a_call = builder_a.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
  std::unique_ptr<HloComputation> a_computation = builder_a.Build();

  // Register the computations in an order that deliberately differs from the
  // dependency post-order, to shake out more possible bugs.
  module->AddEmbeddedComputation(std::move(d_computation));
  module->AddEmbeddedComputation(std::move(c_computation));
  module->AddEntryComputation(std::move(a_computation));
  module->AddEmbeddedComputation(std::move(b_computation));

  auto assignment = RunBufferAssignment(module.get());

  // Each call's buffers are colocated with those of its callee, both at the
  // top level of the tuple ...
  EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
            GetAllocation(*assignment, b_call, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
            GetAllocation(*assignment, c_call, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
            GetAllocation(*assignment, d_param, /*index=*/{}));
  // ... and at tuple element {0}.
  EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
            GetAllocation(*assignment, b_call, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
            GetAllocation(*assignment, c_call, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
            GetAllocation(*assignment, d_param, /*index=*/{0}));

  // The entry parameter's buffers are distinct from every callee parameter's.
  EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));

  // Element {0} is shared along the chain of callee parameters.
  EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}),
            GetAllocation(*assignment, c_param, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}),
            GetAllocation(*assignment, d_param, /*index=*/{0}));
}
1662
TEST_F(BufferAssignmentTest, BitcastAsOutput) {
  // The computation returns a bitcast of its only parameter, keeping the
  // parameter's shape.
  HloComputation::Builder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(F32, {42});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(param->shape(), param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // A single allocation is expected, shared by the parameter and the bitcast.
  EXPECT_EQ(1, assignment->Allocations().size());
  EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, bitcast));
}
1680
TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
  // Test a computation with an output that has an ambiguous points-to set.
  // This is constructed using a select among tuple shapes.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});

  auto tuple_param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  auto tuple_param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  // Parameter number 2 gets its own name; it previously reused "param1",
  // the name of parameter number 1.
  auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto select = builder.AddInstruction(
      HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect,
                                    pred_param, tuple_param0, tuple_param1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Select shallow copies one of its operands so it defines its own top-level
  // buffer and receives its own allocation. The allocation is only inspected,
  // so bind it by reference instead of copying it.
  const auto& select_alloc = GetTopLevelAllocation(*assignment, select);
  EXPECT_EQ(1, select_alloc.assigned_buffers().size());
  EXPECT_EQ(select,
            select_alloc.assigned_buffers().begin()->first->instruction());

  // The buffer for the tuple element of the select is forwarded from one of
  // its operands, which cannot be determined statically. Therefore its slices
  // should include the slices of both of the elements in the parameters.
  auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
  EXPECT_EQ(2, element_slices.size());
  EXPECT_THAT(element_slices,
              UnorderedElementsAre(
                  assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
                      .ConsumeValueOrDie(),
                  assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
                      .ConsumeValueOrDie()));
}
1721
1722 // TODO(b/34669761): Remove this test when buffers are allowed to share
1723 // allocations.
TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
  // The computation wraps a scalar parameter in a tuple, reads it back with
  // GetTupleElement and returns a copy of that element:
  //
  //   param --> (tuple) --> (get-tuple-element) --> (copy)
  HloComputation::Builder builder(TestName());
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param0"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param}));
  HloInstruction* tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape, HloOpcode::kCopy, tuple_element));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // No buffer reuse is expected: three allocations in total, and the copy
  // must not land in the tuple's allocation.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
            GetTopLevelAllocation(*assignment, copy));
}
1746
TEST_F(BufferAssignmentTest, OneTempAllocation) {
  // Builds concat(dot(a, b), dot(b, c)), which needs two temporaries, and
  // checks that both temporaries are placed in one shared allocation, first
  // with alignment 1 and then with alignment 64.
  HloComputation::Builder builder(TestName());
  const Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
  const Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
  const Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
  const Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});

  // dot_ab and dot_bc each need a temp buffer of their own.
  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, shape_4x4, "param_c"));

  // Contract the lhs's dimension 1 with the rhs's dimension 0.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);

  HloInstruction* dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
      shape_2x4, param_a, param_b, dot_dnums, precision_config));
  HloInstruction* dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
      shape_3x4, param_b, param_c, dot_dnums, precision_config));
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));

  // First pass: alignment=1.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);

  // Five allocations: 3 parameters, 1 output, and 1 temp.
  EXPECT_EQ(5, assignment->Allocations().size());

  // Both dot results must live in the same allocation but in different
  // slices of it. With alignment 1 the slices are packed back to back, so
  // the allocation is 32 + 48 = 80 bytes.
  BufferAllocation::Slice slice_ab =
      assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
  BufferAllocation::Slice slice_bc =
      assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
  EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
  EXPECT_NE(slice_ab, slice_bc);
  EXPECT_EQ(32, slice_ab.size());
  EXPECT_EQ(48, slice_bc.size());
  EXPECT_EQ(80, slice_ab.allocation()->size());
  EXPECT_EQ(80, slice_bc.allocation()->size());

  // Second pass: alignment=64.
  assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
  EXPECT_EQ(5, assignment->Allocations().size());
  slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
  slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
  EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
  EXPECT_NE(slice_ab, slice_bc);
  EXPECT_EQ(32, slice_ab.size());
  EXPECT_EQ(48, slice_bc.size());
  // The offsets and the allocation size must reflect the alignment. Either
  // buffer may be placed first, so both layouts are accepted.
  if (slice_ab.offset() != 0) {
    // dot_bc first, dot_ab at the next 64-byte boundary.
    EXPECT_EQ(64, slice_ab.offset());
    EXPECT_EQ(0, slice_bc.offset());
    EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
    EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
  } else {
    // dot_ab first, dot_bc at the next 64-byte boundary.
    EXPECT_EQ(64, slice_bc.offset());
    EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
    EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
  }
}
1820
TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
  // paramscalar -(bc)- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  //
  // The allocation holding mul is expected to report exactly one
  // peak-memory logical buffer, defined by sub.
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  const BufferAllocation& mul_allocation =
      GetTopLevelAllocation(*assignment, mul);
  const std::vector<const HloValue*>& peak_values =
      mul_allocation.PeakMemoryLogicalBuffers();
  ASSERT_EQ(peak_values.size(), 1);
  EXPECT_EQ(peak_values[0]->instruction(), sub);
}
1853
TEST_F(BufferAssignmentTest, PeakBuffers) {
  // Checks the peak-liveness buffers reported for this sequence:
  //
  //   %param = ...
  //   %log = log(%param)
  //   %rev = reverse(%log)
  //   %neg = neg(%param)
  //   %concat = concat(%rev, %neg)
  //   ROOT %root = slice(concat)
  //
  // Within the temporary block, peak memory use is reached at the concat
  // itself, where the live buffers should be {%rev, %neg, %concat}.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
  HloInstruction* rev = builder.AddInstruction(
      HloInstruction::CreateReverse(f32vec100_, log, {0}));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
  HloInstruction* concat = builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
  // A one-element root keeps interior nodes from sharing its buffer.
  const Shape root_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* root = builder.AddInstruction(
      HloInstruction::CreateSlice(root_shape, concat, {0}, {1}, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Pin the schedule to the order listed in the comment above.
  auto assignment = RunBufferAssignmentWithInstructionSequence(
      module.get(), {param, log, rev, neg, concat, root});

  // The temp allocation should contain the four interior instructions.
  const BufferAllocation& temp_allocation =
      GetTopLevelAllocation(*assignment, concat);
  EXPECT_FALSE(temp_allocation.IsInputOrOutput());
  EXPECT_TRUE(temp_allocation.IsPreallocatedTempBuffer());
  ASSERT_EQ(temp_allocation.assigned_buffers().size(), 4);

  const std::vector<const HloValue*>& peak_values =
      temp_allocation.PeakMemoryLogicalBuffers();

  // At the peak, the live set is the concat together with its two inputs.
  ASSERT_EQ(peak_values.size(), 3);
  std::vector<const HloInstruction*> peak_definers;
  for (const HloValue* value : peak_values) {
    peak_definers.push_back(value->instruction());
  }
  EXPECT_THAT(peak_definers, UnorderedElementsAre(rev, neg, concat));
}
1906
TEST_F(BufferAssignmentTest, InPlaceBuffer) {
  // Two chained dynamic-update-slice instructions update the array taken
  // from element 1 of the tuple parameter. Both updates are expected to be
  // assigned the same allocation as that get-tuple-element.
  //
  // The declared shape of the ROOT tuple lists exactly the two operands it
  // is built from; it previously declared four elements.
  const char* hlo_text = R"(
HloModule Module

ENTRY main {
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
constant.2 = s32[] constant(128)
add.5 = s32[] add(get-tuple-element.3, constant.2)
constant.3 = s32[] constant(0)
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
  // `parameter` is the get-tuple-element that extracts the array from the
  // tuple parameter, not the parameter instruction itself.
  HloInstruction* parameter =
      m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
  HloInstruction* dus1 =
      m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
  HloInstruction* dus2 =
      m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");

  auto buffers = RunBufferAssignment(m.get());

  {
    const BufferAllocation& parameter_alloc =
        GetTopLevelAllocation(*buffers, parameter);

    // Both updates must share the allocation of the array they update.
    const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
    EXPECT_EQ(parameter_alloc, dus1_alloc);
    const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
    EXPECT_EQ(parameter_alloc, dus2_alloc);
  }
}
1946
TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
  // A conditional receives a copy of one constant as the true-branch operand
  // and another constant directly as the false-branch operand. Neither
  // constant's allocation may also hold a buffer defined by a copy or by the
  // conditional.
  const char* hlo_text = R"(
HloModule Module

True {
ROOT x.0.1 = f32[] parameter(0)
}

False {
x.0.0 = f32[] parameter(0)
ROOT copy.1 = f32[] copy(x.0.0)
}

ENTRY main {
pred.1.0 = pred[] parameter(0)
constant.1.1 = f32[] constant(56)
copy.2 = f32[] copy(constant.1.1)
constant.1.2 = f32[] constant(12)
ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2),
true_computation=True, false_computation=False
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* constant_1 =
      m->entry_computation()->GetInstructionWithName("constant.1.1");
  HloInstruction* constant_2 =
      m->entry_computation()->GetInstructionWithName("constant.1.2");

  auto buffers = RunBufferAssignment(m.get());

  // Checks that `constant` sits in a constant allocation and that no buffer
  // assigned to that allocation is defined by a kCopy or a kConditional.
  auto expect_unshared_constant_allocation = [this, &buffers](
                                                 HloInstruction* constant) {
    const BufferAllocation& constant_allocation =
        GetTopLevelAllocation(*buffers, constant);
    EXPECT_TRUE(constant_allocation.is_constant());
    for (const auto& assigned : constant_allocation.assigned_buffers()) {
      EXPECT_NE(assigned.first->instruction()->opcode(), HloOpcode::kCopy);
      EXPECT_NE(assigned.first->instruction()->opcode(),
                HloOpcode::kConditional);
    }
  };

  expect_unshared_constant_allocation(constant_1);
  expect_unshared_constant_allocation(constant_2);
}
2004
2005 class WhileBufferAssignmentTest : public HloTestBase {
2006 protected:
BuildWhileConditionComputation(const string & name)2007 std::unique_ptr<HloComputation> BuildWhileConditionComputation(
2008 const string& name) {
2009 auto builder = HloComputation::Builder(name);
2010 builder.AddInstruction(
2011 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2012 auto zero = builder.AddInstruction(
2013 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
2014 auto ten = builder.AddInstruction(
2015 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
2016 builder.AddInstruction(HloInstruction::CreateCompare(
2017 ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
2018 return builder.Build();
2019 }
2020
BuildWhileBodyComputation(const string & name)2021 std::unique_ptr<HloComputation> BuildWhileBodyComputation(
2022 const string& name) {
2023 auto builder = HloComputation::Builder(name);
2024 auto loop_state = builder.AddInstruction(
2025 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2026 auto input = builder.AddInstruction(
2027 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
2028 auto weights = builder.AddInstruction(
2029 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
2030 auto output = builder.AddInstruction(HloInstruction::CreateBinary(
2031 data_shape_, HloOpcode::kMultiply, input, weights));
2032 builder.AddInstruction(
2033 HloInstruction::CreateTuple({input, weights, output}));
2034 return builder.Build();
2035 }
2036
RunBufferAssignment(HloModule * module,int64 alignment=1)2037 std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
2038 int64 alignment = 1) {
2039 HloSchedule schedule =
2040 ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
2041 return BufferAssigner::Run(
2042 module, absl::make_unique<SequentialHloOrdering>(schedule),
2043 ByteSizeOf,
2044 [alignment](LogicalBuffer::Color) { return alignment; },
2045 /*allocate_buffers_for_constants=*/true)
2046 .ConsumeValueOrDie();
2047 }
2048
ByteSizeOf(const BufferValue & buffer)2049 static int64 ByteSizeOf(const BufferValue& buffer) {
2050 return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
2051 }
2052
2053 Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
2054 Shape loop_state_shape_ =
2055 ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
2056 };
2057
RunCopyInsertion(HloModule * module)2058 static void RunCopyInsertion(HloModule* module) {
2059 CopyInsertion copy_insertion;
2060 EXPECT_IS_OK(copy_insertion.Run(module).status());
2061 }
2062
// Chains two while loops, with the second consuming tuple element 2 of the
// first, and checks that each loop's read-only tuple elements share a slice
// with the value that feeds them.
TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
  auto module = CreateNewVerifiedModule();
  auto entry_builder = HloComputation::Builder("entry");

  // Entry parameters.
  auto input0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
  auto weights1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(2, data_shape_, "weights1"));

  // Zero-filled output buffers, one per loop.
  auto zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto output0 = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));
  auto output1 = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // First loop: (input0, weights0, output0).
  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto init0 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, init0));

  // Second loop: (while0{2}, weights1, output1).
  auto cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto input1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
  auto init1 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({input1, weights1, output1}));
  auto while1 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, init1));

  module->AddEntryComputation(entry_builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // 'input0' must alias its read-only use while0{0}.
  {
    const auto source = assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie();
    const auto use = assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie();
    EXPECT_EQ(source, use);
  }
  // 'weights0' must alias its read-only use while0{1}.
  {
    const auto source =
        assignment->GetUniqueSlice(weights0, {}).ConsumeValueOrDie();
    const auto use = assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie();
    EXPECT_EQ(source, use);
  }
  // 'while0{2}' must alias its read-only use while1{0}.
  {
    const auto source =
        assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie();
    const auto use = assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie();
    EXPECT_EQ(source, use);
  }
  // 'weights1' must alias its read-only use while1{1}.
  {
    const auto source =
        assignment->GetUniqueSlice(weights1, {}).ConsumeValueOrDie();
    const auto use = assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie();
    EXPECT_EQ(source, use);
  }
}
2119
2120 // Tests that two colocated buffer sets are not merged if an entry parameter
2121 // buffer belongs to either of the colocation sets (b/73267882).
2122 //
2123 // %param --> %while.0 --> %mul --> %while.1 --> %broadcast
2124 //
2125 // %while.0 body just forwards the init value, so the loop carried variable
2126 // remains the constant, whereas %while.1 changes the loop carried variable.
TEST_F(WhileBufferAssignmentTest,ColocatedBufferWithEntryParameter)2127 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) {
2128 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2129
2130 const char* module_str = R"(
2131 HloModule test_module
2132
2133 %cond.v0 {
2134 %param = s32[] parameter(0)
2135 ROOT %constant = pred[] constant(true)
2136 }
2137
2138 %cond.v1 {
2139 %param.0 = s32[] parameter(0)
2140 ROOT %constant.0 = pred[] constant(true)
2141 }
2142
2143 %body.v0 {
2144 ROOT %param.1 = s32[] parameter(0)
2145 }
2146
2147 %body.v1 {
2148 %param.2 = s32[] parameter(0)
2149 ROOT add = s32[] add(%param.2, %param.2)
2150 }
2151
2152 ENTRY %test_module {
2153 %param.3 = s32[] parameter(0)
2154 %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0
2155 %mul = s32[] multiply(%while.0, %while.0)
2156 %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2157 ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2158 })";
2159
2160 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2161
2162 // Run CopyInsertion and check if the graph constructed above doesn't need
2163 // any copies inserted for BufferAssignment to run.
2164 int64 instruction_count = m->instruction_count();
2165 CopyInsertion copy_insertion;
2166 ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2167 ASSERT_EQ(instruction_count, m->instruction_count());
2168
2169 // Get the instructions in the module.
2170 const HloInstruction* bcast = m->entry_computation()->root_instruction();
2171 const HloInstruction* param =
2172 m->entry_computation()->parameter_instruction(0);
2173 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2174 const HloInstruction* while1 = bcast->operand(0);
2175 ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2176 const HloInstruction* while0 = while1->operand(0)->operand(0);
2177 ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2178
2179 // Run buffer assignment.
2180 auto assignment = RunBufferAssignment(m.get());
2181 TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
2182 assignment->GetUniqueSlice(param, {}));
2183 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2184 assignment->GetUniqueSlice(while0, {}));
2185 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2186 assignment->GetUniqueSlice(while1, {}));
2187
2188 // The parameter slice is part of the while0's colocation set (init value),
2189 // but not merged into the while1's colocation set.
2190 EXPECT_EQ(slice_param, slice_while0);
2191 EXPECT_NE(slice_param, slice_while1);
2192 }
2193
// Same graph as ColocatedBufferWithEntryParameter, but the first while loop is
// initialized from a constant instead of an entry parameter:
//
// %constant.42 --> %while.0 --> %mul --> %while.1 --> %broadcast
//
// The constant's slice must be shared with %while.0 but not with %while.1.
TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
  const char* module_str = R"(
HloModule test_module

%cond.v0 {
  %param = s32[] parameter(0)
  ROOT %constant = pred[] constant(true)
}

%cond.v1 {
  %param.0 = s32[] parameter(0)
  ROOT %constant.0 = pred[] constant(true)
}

%body.v0 {
  ROOT %param.1 = s32[] parameter(0)
}

%body.v1 {
  %param.2 = s32[] parameter(0)
  ROOT add = s32[] add(%param.2, %param.2)
}

ENTRY %test_module {
  %constant.42 = s32[] constant(42)
  %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
  %mul = s32[] multiply(%while.0, %while.0)
  %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  // Run CopyInsertion and check that the graph constructed above doesn't need
  // any copies inserted for BufferAssignment to run.
  int64 instruction_count = m->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
  ASSERT_EQ(instruction_count, m->instruction_count());

  // Get the instructions in the module by walking back from the root:
  // bcast -> while.1 -> mul -> while.0.
  const HloInstruction* bcast = m->entry_computation()->root_instruction();
  const HloInstruction* constant =
      m->entry_computation()->GetInstructionWithName("constant.42");
  // Stop here if the lookup failed; the slice query below needs the constant.
  ASSERT_NE(constant, nullptr);
  ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
  const HloInstruction* while1 = bcast->operand(0);
  ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
  const HloInstruction* while0 = while1->operand(0)->operand(0);
  ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);

  // Run buffer assignment.
  auto assignment = RunBufferAssignment(m.get());
  TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
                          assignment->GetUniqueSlice(constant, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));

  // The constant slice is part of the while0's colocation set (init value), but
  // not merged into the while1's colocation set.
  EXPECT_EQ(slice_constant, slice_while0);
  EXPECT_NE(slice_constant, slice_while1);
}
2260
// Tests that the colocated buffers for while instructions are properly assigned
// during buffer assignment such that the result tuple elements are not assigned
// to the same buffer.
//
// %infeed --> %while.0 --> %while.1 --+
//                                     +-- %tuple
//   %zero -->   %add   --> %while.2 --+
//
// Execution Order:
// %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
//
// The HLO computation used in this test requires specific ordering to expose
// the bug (b/72496031). During buffer assignment, the visitation order of
// colocated buffers is %while.2 -> while.0 -> while.1, and the buffer
// assignment was coalescing the colocated buffers for all 3 while instructions,
// therefore assigning the same buffer to the two result tuple elements.
TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
  const Shape r0s32 = ShapeUtil::MakeShape(S32, {});

  // Builds a condition computation: x -> x < 4
  auto build_cond = [&]() {
    auto builder = HloComputation::Builder("cond");
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
    builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
                                      const4, ComparisonDirection::kLt));
    return builder.Build();
  };

  // Builds a body computation: x -> x + 9
  auto build_body = [&]() {
    auto builder = HloComputation::Builder("body");
    auto const9 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
    return builder.Build();
  };

  // Build the entry computation as described in the comment above.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder("entry");

  // First chain: infeed data feeds while0, whose result feeds while1.
  auto token = builder.AddInstruction(HloInstruction::CreateToken());
  auto infeed =
      builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
  auto infeed_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
  auto cond0 = module->AddEmbeddedComputation(build_cond());
  auto body0 = module->AddEmbeddedComputation(build_body());
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));

  auto cond1 = module->AddEmbeddedComputation(build_cond());
  auto body1 = module->AddEmbeddedComputation(build_body());
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond1, body1, while0));

  // Second chain: zero + zero feeds while2.
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
  auto cond2 = module->AddEmbeddedComputation(build_cond());
  auto body2 = module->AddEmbeddedComputation(build_body());
  auto while2 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond2, body2, add));

  // Result tuple: element 0 is while2, element 1 is while1.
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
  module->AddEntryComputation(builder.Build());

  // Run CopyInsertion and check if the graph constructed above doesn't need
  // any copies inserted for BufferAssignment to run.
  int64 instruction_count = module->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  ASSERT_EQ(instruction_count, module->instruction_count());

  // Create a sequential order among all the instructions in the entry
  // computation, since the issue this test stresses depends on the order the
  // nodes are traversed during BufferAssignment.
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(),
                                     /*pointer_size=*/sizeof(void*));
      }));
  // Override the entry computation's sequence with the execution order given
  // in the comment above this test.
  schedule.set_sequence(
      module->entry_computation(),
      {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
  TF_ASSERT_OK(schedule.Verify());

  TF_ASSERT_OK_AND_ASSIGN(
      auto assignment,
      BufferAssigner::Run(
          module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
          backend().compiler()->BufferSizeBytesFunction(),
          [](LogicalBuffer::Color) { return 1; },
          /*allocate_buffers_for_constants=*/true));

  // The result tuple elements must be assigned with different buffers.
  TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
  EXPECT_NE(slice0, slice1);

  // while0 and while1 result buffers must be equal to slice1.
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));
  EXPECT_EQ(slice1, slice_while0);
  EXPECT_EQ(slice1, slice_while1);

  // while2 result buffer must be equal to slice0.
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
                          assignment->GetUniqueSlice(while2, {}));
  EXPECT_EQ(slice0, slice_while2);
}
2384
// Feeds the whole result tuple of one while loop directly into a second while
// loop and checks that every tuple element of the two loops shares a slice.
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
  auto module = CreateNewVerifiedModule();
  auto entry_builder = HloComputation::Builder("entry");

  auto input0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  auto zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto output0 = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // First loop, initialized from (input0, weights0, output0).
  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto init0 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, init0));

  // Second loop, initialized from the first loop's result.
  auto cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto while1 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));

  module->AddEntryComputation(entry_builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // while0 and while1 buffers should be completely aligned: compare the
  // slices element by element.
  {
    const auto first = assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie();
    const auto second =
        assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie();
    EXPECT_EQ(first, second);
  }
  {
    const auto first = assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie();
    const auto second =
        assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie();
    EXPECT_EQ(first, second);
  }
  {
    const auto first = assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie();
    const auto second =
        assignment->GetUniqueSlice(while1, {2}).ConsumeValueOrDie();
    EXPECT_EQ(first, second);
  }
}
2429
// Calls the same sub-computation (param + 1.0) twice from the entry
// computation and checks, after flattening the call graph and inserting
// copies, that the two call results are assigned distinct buffers.
TEST_F(BufferAssignmentTest, TwoCalls) {
  auto module = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
  HloComputation* sub_computation;
  {
    // Sub-computation: param -> param + 1.0.
    auto builder = HloComputation::Builder(TestName() + "_sub_comp");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32, "param"));
    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
    sub_computation = module->AddEmbeddedComputation(builder.Build(add));
  }
  // Entry computation: call(3.0) + (call(2.0) + 2.0).
  auto builder = HloComputation::Builder(TestName());
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  auto call1 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
  auto call2 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
  auto add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
  module->AddEntryComputation(builder.Build(add2));

  {
    // The sub-computation has two call sites, so the flattening pass is
    // expected to report that it changed the module.
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
    EXPECT_TRUE(result);
  }

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}
2471
// Checks that the parameters of a called computation are placed in the same
// allocations as the caller-side values passed to the call, including the
// nested elements of a tuple operand.
TEST_F(BufferAssignmentTest, CallParamCoAllocation) {
  const char* hlo_text = R"(
HloModule CallParamCoAllocation

Callee {
  param0 = (f32[100],(f32[200],f32[300])) parameter(0)
  param1 = s32[20] parameter(1)
  ROOT constant = f32[] constant(1)
}

ENTRY Main {
  entry_param0 = f32[100] parameter(0)
  entry_param1 = s32[20] parameter(1)
  custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target"
  call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call)
  ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee
}
)";

  HloModuleConfig config;
  config.set_debug_options(GetDebugOptionsFromFlags());
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(hlo_text, config));

  auto buffers = RunBufferAssignment(m.get());

  HloComputation* main = m->entry_computation();
  HloComputation* callee = m->GetComputationWithName("Callee");
  // Fatal assertion: `callee` is dereferenced immediately below, so a failed
  // lookup must stop the test instead of crashing it.
  ASSERT_NE(callee, nullptr);

  HloInstruction* param0 = callee->parameter_instruction(0);
  HloInstruction* param1 = callee->parameter_instruction(1);

  HloInstruction* entry_param0 = main->parameter_instruction(0);
  HloInstruction* entry_param1 = main->parameter_instruction(1);
  HloInstruction* custom_call = main->GetInstructionWithName("custom_call");
  ASSERT_NE(custom_call, nullptr);

  // Array operands share an allocation with the matching callee parameter.
  EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}),
            GetAllocation(*buffers, param0, {0}));
  EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}),
            GetAllocation(*buffers, param1, {}));

  // The custom call's tuple and both of its elements map onto element 1 of
  // the callee's tuple parameter.
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {}),
            GetAllocation(*buffers, param0, {1}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}),
            GetAllocation(*buffers, param0, {1, 0}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}),
            GetAllocation(*buffers, param0, {1, 1}));
}
2521
// Runs buffer assignment on a small module with a fixed instruction sequence
// and compares BufferAssignment::BufferInfoString() against the expected CSV
// text.
TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
  absl::string_view module_str = R"(
HloModule test_module

ENTRY %test_module {
  %param.0 = s32[1024]{0} parameter(0)
  %param.1 = s32[1024]{0} parameter(1)
  %mul = s32[1024]{0} multiply(%param.0, %param.1)
  %add = s32[1024]{0} add(%mul, %param.0)
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[1024] %add), dimensions={0}
})";

  // Expected output. The comparison is exact, so the CSV rows must start at
  // column zero inside the raw string.
  absl::string_view reference_str =
      R"(buffer_id,buffer_name,offset,size,definition_time,end_time,num_uses,use_times,use_names
0,"<0 param.0 @0>",0,4096,0,5,2,"2;3","mul, operand 0;add, operand 1"
1,"<1 param.1 @0>",0,4096,1,5,1,"2","mul, operand 1"
2,"<2 mul @0>",0,4096,2,3,1,"3","add, operand 0"
3,"<3 add @0>",0,4096,3,4,1,"4","bcast, operand 0"
4,"<4 bcast @0>",0,4194304,4,5,0,"",""
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
  HloModule* const hlo_module = m.get();
  HloInstruction* const param0 = FindInstruction(hlo_module, "param.0");
  HloInstruction* const param1 = FindInstruction(hlo_module, "param.1");
  HloInstruction* const mul = FindInstruction(hlo_module, "mul");
  HloInstruction* const add = FindInstruction(hlo_module, "add");
  HloInstruction* const bcast = FindInstruction(hlo_module, "bcast");

  // Run buffer assignment with the instructions in program order.
  auto assignment = RunBufferAssignmentWithInstructionSequence(
      hlo_module, {param0, param1, mul, add, bcast});

  const std::string actual_info = assignment->BufferInfoString();
  EXPECT_EQ(actual_info, reference_str);
}
2556
// Builds two independent while loops whose results are both read afterwards,
// schedules the entry computation in a hand-written order, and checks that
// the two loops are assigned distinct buffers (b/38494731).
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
  auto module = CreateNewVerifiedModule();
  auto entry_builder = HloComputation::Builder(TestName());

  auto zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto one = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // State for the first loop.
  auto input0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
  auto output0 = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // State for the second loop.
  auto input1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(2, data_shape_, "input1"));
  auto weights1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(3, data_shape_, "weights1"));
  auto output1 = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, one, {}));

  // Both loops are created against the same condition and body computations.
  auto shared_cond =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto shared_body =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto init0 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto init1 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({input1, weights1, output1}));

  auto while0 = entry_builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape_, shared_cond, shared_body, init0));
  auto while1 = entry_builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape_, shared_cond, shared_body, init1));

  auto gte0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
  auto gte1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
  auto root_add = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));

  module->AddEntryComputation(entry_builder.Build());

  {
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool flattened, flatten.Run(module.get()));
    EXPECT_TRUE(flattened);
  }

  RunCopyInsertion(module.get());

  HloSchedule schedule =
      ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();

  // To trigger b/38494731, we want a specific Hlo schedule for the
  // root computation, so we overwrite that entry with a manually
  // crafted sequence. The loop operands are read back from the while
  // instructions rather than taken from the tuples created above.
  HloInstruction* const while1_operand = while1->mutable_operand(0);
  HloInstruction* const while0_operand = while0->mutable_operand(0);
  schedule.set_sequence(
      module->entry_computation(),
      {input1, weights1, one, output1, while1_operand, while1, input0,
       weights0, zero, output0, while0_operand, while0, gte0, gte1, root_add});

  // If this ASSERT fails, we constructed a bogus sequence above and this test
  // itself is buggy.
  TF_ASSERT_OK(schedule.Verify());

  auto assignment_or = BufferAssigner::Run(
      module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
      ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
      /*allocate_buffers_for_constants=*/true);
  auto assignment = assignment_or.ConsumeValueOrDie();

  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
2636
// Chains two while loops and makes element 2 of the second loop the last
// instruction of the entry computation. Checks that this instruction's
// allocation is live out and is not an entry-computation parameter allocation.
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
  auto module = CreateNewVerifiedModule();
  auto entry_builder = HloComputation::Builder("entry");

  auto input0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  auto zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto output0 = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));
  auto output1 = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto init0 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, init0));

  // Element 2 of 'while0' becomes the input of 'while1'.
  auto while0_out = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));

  auto cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto init1 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({while0_out, weights0, output1}));
  auto while1 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, init1));

  // Element 2 of 'while1' is read last so that it is live out of the
  // computation.
  auto while1_out = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));

  module->AddEntryComputation(entry_builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Look up the allocation backing the final instruction.
  const auto root_slice =
      assignment->GetUniqueTopLevelSlice(while1_out).ConsumeValueOrDie();
  auto* root_alloc = root_slice.allocation();

  // The allocation must be live out ...
  EXPECT_TRUE(root_alloc->maybe_live_out());
  // ... and must not be an entry parameter allocation.
  EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
}
2693
// Checks that two chained dynamic-update-slice instructions inside a while
// body are assigned the same allocation slice.
TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) {
  // The declared shape of ROOT tuple.85 matches its two operands (and the
  // while loop state).
  const char* const hlo_string = R"(
HloModule test

while_body {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}

while_condition {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  get-tuple-element = s32[] get-tuple-element(state), index=0
  get-tuple-element.1 = s32[] constant(3)
  ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
}

ENTRY entry_computation {
  constant.7 = s32[] constant(0)
  copy.1 = s32[] copy(constant.7)
  constant.6 = f32[] constant(0)
  broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={}
  tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6)
  while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body
  ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0
}

)";
  // Report a parse/verification error as a test failure rather than dying.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Get the top-level slices of the two dynamic-update-slice instructions.
  auto dus9 = FindInstruction(module.get(), "dynamic-update-slice.9");
  auto dus9_alloc_slice =
      assignment->GetUniqueTopLevelSlice(dus9).ConsumeValueOrDie();
  auto dus5 = FindInstruction(module.get(), "dynamic-update-slice.5");
  auto dus5_alloc_slice =
      assignment->GetUniqueTopLevelSlice(dus5).ConsumeValueOrDie();

  // Test that the two dynamic-update-slice ops share the same allocation slice.
  EXPECT_EQ(dus9_alloc_slice.allocation(), dus5_alloc_slice.allocation());
  EXPECT_EQ(dus9_alloc_slice, dus5_alloc_slice);
}
2746 } // namespace
2747 } // namespace xla
2748