1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/service/buffer_value.h"
29 #include "tensorflow/compiler/xla/service/call_graph.h"
30 #include "tensorflow/compiler/xla/service/copy_insertion.h"
31 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
32 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
38 #include "tensorflow/compiler/xla/service/hlo_parser.h"
39 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/test.h"
42 #include "tensorflow/compiler/xla/test_helpers.h"
43 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/xla_data.pb.h"
46 #include "tensorflow/core/lib/core/status_test_util.h"
47 #include "tensorflow/core/platform/macros.h"
48 
49 namespace xla {
50 namespace {
51 
52 using ::testing::UnorderedElementsAre;
53 
54 // DFS visitor that collects the instructions referenced by a computation
55 // without descending into nested computations, i.e., only from the operands.
56 class InstructionListVisitor : public DfsHloVisitorWithDefault {
57  public:
InstructionListVisitor(const HloInstruction * root)58   explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}
59 
DefaultAction(HloInstruction * hlo)60   Status DefaultAction(HloInstruction* hlo) override {
61     // For each instruction, just push it on the list after walking the
62     // operands.
63     instructions_.push_back(hlo);
64     VLOG(0) << "List instruction " << hlo->ToString();
65     return Status::OK();
66   }
67 
GetInstructions()68   std::vector<const HloInstruction*> GetInstructions() { return instructions_; }
69 
70  private:
71   // The instruction root of the computation.
72   const HloInstruction* root_;
73 
74   // The full set of instructions found (may be duplicates, e.g., kParameter).
75   std::vector<const HloInstruction*> instructions_;
76 
77   TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor);
78 };
79 
GetInstructions(HloInstruction * root)80 const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
81   InstructionListVisitor main_list(root);
82   TF_CHECK_OK(root->Accept(&main_list));
83   return main_list.GetInstructions();
84 }
85 
86 class BufferAssignmentTest : public HloTestBase {
87  protected:
~BufferAssignmentTest()88   ~BufferAssignmentTest() override {}
89 
RunBufferAssignment(HloModule * module,int64 alignment=1)90   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
91                                                         int64 alignment = 1) {
92     return BufferAssigner::Run(
93                module, absl::make_unique<DependencyHloOrdering>(module),
94                backend().compiler()->BufferSizeBytesFunction(),
95                [alignment](LogicalBuffer::Color) { return alignment; },
96                /*allocate_buffers_for_constants=*/true)
97         .ConsumeValueOrDie();
98   }
99 
RunBufferAssignmentNoBuffersForConstants(HloModule * module,int64 alignment=1)100   std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
101       HloModule* module, int64 alignment = 1) {
102     return BufferAssigner::Run(
103                module, absl::make_unique<DependencyHloOrdering>(module),
104                backend().compiler()->BufferSizeBytesFunction(),
105                [alignment](LogicalBuffer::Color) { return alignment; },
106                /*allocate_buffers_for_constants=*/false)
107         .ConsumeValueOrDie();
108   }
109 
RunBufferAssignmentNoBuffersReuseForAdd(HloModule * module,int64 alignment=1)110   std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersReuseForAdd(
111       HloModule* module, int64 alignment = 1) {
112     absl::flat_hash_set<HloOpcode> must_not_live_out = {HloOpcode::kAdd};
113 
114     return BufferAssigner::Run(
115                module, absl::make_unique<DependencyHloOrdering>(module),
116                backend().compiler()->BufferSizeBytesFunction(),
117                [alignment](LogicalBuffer::Color) { return alignment; },
118                /*allocate_buffers_for_constants=*/false,
119                /*colorer=*/BufferAssigner::DefaultColorer(),
120                /*must_not_live_out=*/must_not_live_out)
121         .ConsumeValueOrDie();
122   }
123 
RunColoredBufferAssignment(HloModule * module,BufferAssigner::Colorer colorer,int64 alignment=1)124   std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
125       HloModule* module, BufferAssigner::Colorer colorer, int64 alignment = 1) {
126     return BufferAssigner::Run(
127                module, absl::make_unique<DependencyHloOrdering>(module),
128                backend().compiler()->BufferSizeBytesFunction(),
129                [alignment](LogicalBuffer::Color) { return alignment; },
130                /*allocate_buffers_for_constants=*/true, std::move(colorer))
131         .ConsumeValueOrDie();
132   }
133 
RunBufferAssignmentWithInstructionSequence(HloModule * module,absl::Span<HloInstruction * const> instruction_sequence,int64 alignment=1)134   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
135       HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
136       int64 alignment = 1) {
137     HloSchedule schedule(module);
138     schedule.set_sequence(module->entry_computation(), instruction_sequence);
139     return BufferAssigner::Run(
140                module, absl::make_unique<SequentialHloOrdering>(schedule),
141                backend().compiler()->BufferSizeBytesFunction(),
142                [alignment](LogicalBuffer::Color) { return alignment; },
143                /*allocate_buffers_for_constants=*/true)
144         .ConsumeValueOrDie();
145   }
146 
RunBufferAssignmentWithPresetAssignments(HloModule * module,std::unique_ptr<PresetAssignments> preset_assignments,int64 alignment=1)147   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithPresetAssignments(
148       HloModule* module, std::unique_ptr<PresetAssignments> preset_assignments,
149       int64 alignment = 1) {
150     return BufferAssigner::Run(
151                module, absl::make_unique<DependencyHloOrdering>(module),
152                backend().compiler()->BufferSizeBytesFunction(),
153                [alignment](LogicalBuffer::Color) { return alignment; },
154                /*allocate_buffers_for_constants=*/true,
155                BufferAssigner::DefaultColorer(),
156                /*must_not_live_out=*/{},
157                /*can_share_buffer=*/nullptr, std::move(preset_assignments))
158         .ConsumeValueOrDie();
159   }
160 
161   // Builds an x+1.0 computation to use in a Map.
BuildMapComputationPlus1(const string & name)162   std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
163     auto builder = HloComputation::Builder(name);
164     auto param =
165         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
166     auto value = builder.AddInstruction(
167         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
168     builder.AddInstruction(
169         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
170     return builder.Build();
171   }
172 
BuildReduceComputation(const string & name)173   std::unique_ptr<HloComputation> BuildReduceComputation(const string& name) {
174     auto builder = HloComputation::Builder(name);
175     auto param =
176         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
177     auto param2 =
178         builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
179     builder.AddInstruction(
180         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
181     return builder.Build();
182   }
183 
184   // Builds a simple compare-to-limit (x < 4) computation for a While.
185   //
186   // condition:
187   //   const4[s32] -----------------------------------\
188   //                                                   \
189   //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
190   //
BuildWhileConditionComputation(const string & name)191   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
192       const string& name) {
193     auto builder = HloComputation::Builder(name);
194     auto const4 = builder.AddInstruction(
195         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
196     auto param = builder.AddInstruction(
197         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
198     auto index = builder.AddInstruction(
199         HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
200     builder.AddInstruction(
201         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
202                                       const4, ComparisonDirection::kLt));
203     return builder.Build();
204   }
205 
206   // Builds a simple body computation for a While.
207   //
208   // body:
209   //   constv[f32[4]] --------------------------------------\
210   //                                                         \
211   //                           /--- get-tuple-elementv[1] --- addv ---\
212   //   param[(s32,f32[4])] ---|                                    tuple
213   //                           \--- get-tuple-elementc[0] --- addc ---/
214   //                                                         /
215   //   const1[s32] -----------------------------------------/
216   //
BuildWhileBodyComputation(const string & name)217   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
218       const string& name) {
219     auto builder = HloComputation::Builder(name);
220     auto const1 = builder.AddInstruction(
221         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
222     auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
223         LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
224     auto param = builder.AddInstruction(
225         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
226     auto indexc = builder.AddInstruction(
227         HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
228     auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
229         indexc->shape(), HloOpcode::kAdd, indexc, const1));
230     auto indexv = builder.AddInstruction(
231         HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
232     auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
233         constv->shape(), HloOpcode::kAdd, indexv, constv));
234     builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
235     return builder.Build();
236   }
237 
BuildR0F32UnaryOpComputation(HloOpcode opcode,const string & name)238   std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
239       HloOpcode opcode, const string& name) {
240     auto builder = HloComputation::Builder(name);
241     auto param =
242         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
243     builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
244     return builder.Build();
245   }
246 
247   // Verifies that the given instruction hlo has a valid input buffer assigned,
248   // i.e., the parameter number matches the op's.
GetAssignedInputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)249   const BufferAllocation& GetAssignedInputAllocation(
250       const BufferAssignment& buffers, HloInstruction* hlo) {
251     LOG(INFO) << "Checking input: " << hlo->ToString();
252     const BufferAllocation& buffer =
253         *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation();
254     EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
255     return buffer;
256   }
257 
258   // Verifies that the given instruction hlo has a valid output buffer
259   // assigned, and returns it.
GetAssignedOutputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)260   const BufferAllocation& GetAssignedOutputAllocation(
261       const BufferAssignment& buffers, HloInstruction* hlo) {
262     LOG(INFO) << "Checking output: " << hlo->ToString();
263     const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
264     return buffer;
265   }
266 
267   // Returns the allocation for the given instruction.
GetAllocation(const BufferAssignment & buffers,const HloInstruction * hlo,const ShapeIndex & index)268   const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
269                                         const HloInstruction* hlo,
270                                         const ShapeIndex& index) {
271     return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation();
272   }
GetTopLevelAllocation(const BufferAssignment & buffers,const HloInstruction * hlo)273   const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
274                                                 const HloInstruction* hlo) {
275     return *buffers.GetUniqueTopLevelSlice(hlo)
276                 .ConsumeValueOrDie()
277                 .allocation();
278   }
279 
280   // Verifies that all instructions in the given instruction list except
281   // kConstant have assigned buffers, and returns their total size. If min_index
282   // and max_index are not nullptr, the minimum and maximum buffer indices in
283   // the assignment are written into them.
ValidateBuffers(const std::vector<const HloInstruction * > & instructions,const BufferAssignment & buffers)284   int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions,
285                         const BufferAssignment& buffers) {
286     // Verifies all instructions have buffers, and gets the index ranges.
287     for (const HloInstruction* hlo : instructions) {
288       if (!buffers.HasTopLevelAllocation(hlo)) {
289         // If `hlo` has no assigned buffer, it is either a constant or a nested
290         // parameter.
291         EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
292                     HloOpcode::kParameter == hlo->opcode());
293         continue;
294       }
295     }
296 
297     // Gets the total size of all buffers assigned.
298     int64 total_size = 0;
299     for (auto& allocation : buffers.Allocations()) {
300       total_size += allocation.size();
301     }
302     return total_size;
303   }
304 
305   // Shapes for use in the examples.
306   Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
307   Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
308   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
309   Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
310   Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
311   Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
312   Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
313   Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
314 };
315 
316 // Returns true if the buffers assigned to instructions in "a" are distinct
317 // from the buffers assigned to those in "b" (ie, intersection is empty).
BuffersDistinct(const std::vector<const HloInstruction * > & a,const std::vector<const HloInstruction * > & b,const BufferAssignment & assignment)318 static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
319                             const std::vector<const HloInstruction*>& b,
320                             const BufferAssignment& assignment) {
321   absl::flat_hash_set<BufferAllocation::Slice> a_slices;
322   for (const HloInstruction* instruction : a) {
323     if (assignment.HasTopLevelAllocation(instruction)) {
324       a_slices.insert(
325           assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie());
326     }
327   }
328 
329   for (const HloInstruction* instruction : b) {
330     if (assignment.HasTopLevelAllocation(instruction)) {
331       if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction)
332                                 .ConsumeValueOrDie())) {
333         return false;
334       }
335     }
336   }
337   return true;
338 }
339 
340 // Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest,ScalarConstant)341 TEST_F(BufferAssignmentTest, ScalarConstant) {
342   auto builder = HloComputation::Builder(TestName());
343   auto const0 = builder.AddInstruction(
344       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
345   auto module = CreateNewVerifiedModule();
346   module->AddEntryComputation(builder.Build());
347 
348   {
349     auto buffers = RunBufferAssignment(module.get());
350     EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
351   }
352 
353   {
354     auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
355     EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
356   }
357 }
358 
// Adds two vector constants. With allocate_buffers_for_constants=true both
// constants receive top-level buffers; with it false neither does. In both
// configurations their consumer (the add, which is the computation's root)
// must get an output buffer.
TEST_F(BufferAssignmentTest, BufferForConst) {
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    // Constants are allocated.
    auto buffers = RunBufferAssignment(module.get());
    EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
    EXPECT_TRUE(buffers->HasTopLevelAllocation(const1));
    GetAssignedOutputAllocation(*buffers, add);
  }
  {
    // Constants are not allocated.
    auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
    EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
    GetAssignedOutputAllocation(*buffers, add);
  }
}
385 
// Builds tuple(negate(param0), param0, constant) and checks that for each
// tuple position HasAllocationAt(tuple, index) agrees with what
// HasTopLevelAllocation() reports for the instruction feeding that position.
TEST_F(BufferAssignmentTest, HasAllocationAt) {
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* one = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  HloInstruction* neg = b.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, p0));
  HloInstruction* tup =
      b.AddInstruction(HloInstruction::CreateTuple({neg, p0, one}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  const auto assignment = RunBufferAssignment(module.get());

  // The tuple itself (empty index).
  EXPECT_EQ(assignment->HasTopLevelAllocation(tup),
            assignment->HasAllocationAt(tup, /*index=*/{}));
  // Element 0 comes from the negate.
  EXPECT_EQ(assignment->HasTopLevelAllocation(neg),
            assignment->HasAllocationAt(tup, /*index=*/{0}));
  // Element 1 is the parameter.
  EXPECT_EQ(assignment->HasTopLevelAllocation(p0),
            assignment->HasAllocationAt(tup, /*index=*/{1}));
  // Element 2 is the constant.
  EXPECT_EQ(assignment->HasTopLevelAllocation(one),
            assignment->HasAllocationAt(tup, /*index=*/{2}));
}
413 
// The computation's result is a copy of a vector constant; that copy must
// receive an output allocation.
TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  HloComputation::Builder b(TestName());
  HloInstruction* vec = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* copied = b.AddInstruction(
      HloInstruction::CreateUnary(vec->shape(), HloOpcode::kCopy, vec));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  // Dies (via ConsumeValueOrDie) if the copy lacks a unique top-level slice.
  GetAssignedOutputAllocation(*RunBufferAssignment(module.get()), copied);
}
428 
// Computes sub(add(mul(broadcast(paramscalar), param0), param1), param1):
//
// paramscalar -- broadcast -- (mul) -- (add) -- (sub)
//                             /        /        /
// param0[100] ---------------/        /        /
//                                    /        /
// param1[100] ----------------------/--------/
TEST_F(BufferAssignmentTest, Basic) {
  HloComputation::Builder b(TestName());
  HloInstruction* scalar_param =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* bcast = b.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, scalar_param, {}));
  HloInstruction* lhs_param = b.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* rhs_param = b.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* product = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, bcast, lhs_param));
  HloInstruction* sum = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kAdd, product, rhs_param));
  HloInstruction* difference = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, sum, rhs_param));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  auto assignment = RunBufferAssignment(module.get());

  // Each parameter lives in its own input allocation.
  const BufferAllocation& scalar_alloc =
      GetAssignedInputAllocation(*assignment, scalar_param);
  const BufferAllocation& lhs_alloc =
      GetAssignedInputAllocation(*assignment, lhs_param);
  const BufferAllocation& rhs_alloc =
      GetAssignedInputAllocation(*assignment, rhs_param);
  EXPECT_NE(scalar_alloc.index(), lhs_alloc.index());
  EXPECT_NE(scalar_alloc.index(), rhs_alloc.index());
  EXPECT_NE(lhs_alloc.index(), rhs_alloc.index());

  // The multiply gets a buffer distinct from its parameter input.
  const BufferAllocation& product_alloc =
      GetTopLevelAllocation(*assignment, product);
  EXPECT_NE(product_alloc.index(), lhs_alloc.index());

  // The add is allowed to (and does) reuse the multiply's buffer.
  const BufferAllocation& sum_alloc = GetTopLevelAllocation(*assignment, sum);
  EXPECT_EQ(sum_alloc.index(), product_alloc.index());

  // The subtract (the result) has an output allocation.
  GetAssignedOutputAllocation(*assignment, difference);
}
475 
// When the input parameter is aliased with the output, its buffer may also be
// reused for the intermediate result, so all three values share one buffer.
//
// param0[100] ----- (neg1) -- (neg2)
//    |                           |
//    +-------- Aliased ----------+
TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p0"));
  HloInstruction* first_neg = b.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, p0));
  HloInstruction* second_neg = b.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, first_neg));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  // Set up the output <-> parameter 0 alias drawn in the diagram above.
  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {}));

  auto assignment = RunBufferAssignment(module.get());

  const BufferAllocation& p0_alloc =
      GetAssignedInputAllocation(*assignment, p0);
  const BufferAllocation& first_alloc =
      GetAllocation(*assignment, first_neg, {});
  const BufferAllocation& second_alloc =
      GetAllocation(*assignment, second_neg, {});

  // A single allocation is used for everything.
  EXPECT_EQ(p0_alloc.index(), first_alloc.index());
  EXPECT_EQ(second_alloc.index(), first_alloc.index());
}
508 
TEST_F(BufferAssignmentTest, AddCannotReuse) {
  // Same graph as the Basic test, but buffer assignment is told (via the
  // must_not_live_out opcode set) that the output of kAdd must not be live
  // out.
  //
  // paramscalar -- broadcast -- (mul) -- (add) -- (sub)
  //                             /        /        /
  // param0[100] ---------------/        /        /
  //                                    /        /
  // param1[100] ----------------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get());

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The sub node (the computation's root, hence its live-out result) has a
  // valid buffer assigned that is not shared with input param0.
  const BufferAllocation& sub_buffer = GetTopLevelAllocation(*buffers, sub);
  EXPECT_NE(sub_buffer.index(), param0_buffer.index());

  // Because the add's output must not be live out, the add cannot share the
  // live-out buffer of sub.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), sub_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
558 
TEST_F(BufferAssignmentTest,BasicUniquelyColored)559 TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
560   // paramscalar ------- (mul) -- (add) -- (sub)
561   //                     /        /        /
562   // param0[100] -------/        /        /
563   //                            /        /
564   // param1[100] --------------/--------/
565   // The output of each op is colored with a different color, so we can not
566   // share anything.
567   auto builder = HloComputation::Builder(TestName());
568   auto paramscalar =
569       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
570   auto broadcast = builder.AddInstruction(
571       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
572   auto param0 = builder.AddInstruction(
573       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
574   auto param1 = builder.AddInstruction(
575       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
576   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
577       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
578   auto add = builder.AddInstruction(
579       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
580   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
581       f32vec100_, HloOpcode::kSubtract, add, param1));
582   auto module = CreateNewVerifiedModule();
583   module->AddEntryComputation(builder.Build());
584 
585   auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
586     int color = 0;
587     for (HloValue::Id id = 0;
588          id < alias_analysis->dataflow_analysis().values().size(); id++) {
589       auto& value = alias_analysis->dataflow_analysis().GetValue(id);
590       value.set_color(BufferValue::Color(color++));
591     }
592     return Status::OK();
593   };
594 
595   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
596 
597   // Distinct input buffers were assigned for parameters.
598   BufferAllocation paramscalar_buffer =
599       GetAssignedInputAllocation(*buffers, paramscalar);
600   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
601   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
602   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
603   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
604   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
605 
606   // The mul node has a valid buffer assigned, doesn't share with input.
607   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
608   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
609 
610   // The add node can not reuse the mul node's buffer due to coloring.
611   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
612   EXPECT_NE(add_buffer.index(), mul_buffer.index());
613 
614   // The sub node has a valid output buffer assigned.
615   GetAssignedOutputAllocation(*buffers, sub);
616 
617   // Check if the HLO instructions have the correct colors in the layout.
618   EXPECT_EQ(param0->shape().layout().memory_space(), 2);
619   EXPECT_EQ(param1->shape().layout().memory_space(), 3);
620   EXPECT_EQ(mul->shape().layout().memory_space(), 4);
621   EXPECT_EQ(add->shape().layout().memory_space(), 5);
622   EXPECT_EQ(sub->shape().layout().memory_space(), 6);
623 }
624 
TEST_F(BufferAssignmentTest,BasicPartiallyColored)625 TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
626   // paramscalar ------- (mul) -- (add) -- (sub)
627   //                     /        /        /
628   // param0[100] -------/        /        /
629   //                            /        /
630   // param1[100] --------------/--------/
631   // The output of the mul and the add have the color 1, and the other buffers
632   // have the color 0, which allows the mul and add to share buffers.
633   auto builder = HloComputation::Builder(TestName());
634   auto paramscalar =
635       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
636   auto broadcast = builder.AddInstruction(
637       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
638   auto param0 = builder.AddInstruction(
639       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
640   auto param1 = builder.AddInstruction(
641       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
642   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
643       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
644   auto add = builder.AddInstruction(
645       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
646   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
647       f32vec100_, HloOpcode::kSubtract, add, param1));
648   auto module = CreateNewVerifiedModule();
649   module->AddEntryComputation(builder.Build());
650 
651   auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
652     for (HloValue::Id id = 0;
653          id < alias_analysis->dataflow_analysis().values().size(); id++) {
654       auto& value = alias_analysis->dataflow_analysis().GetValue(id);
655       auto& buffer = alias_analysis->GetBufferContainingValue(value);
656       for (const auto& alias : buffer.values()) {
657         if (alias->instruction()->opcode() == HloOpcode::kAdd ||
658             alias->instruction()->opcode() == HloOpcode::kMultiply) {
659           value.set_color(LogicalBuffer::Color(1));
660         }
661       }
662       if (!value.has_color()) {
663         value.set_color(LogicalBuffer::Color(0));
664       }
665     }
666     return Status::OK();
667   };
668 
669   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
670 
671   // Distinct input buffers were assigned for parameters.
672   BufferAllocation paramscalar_buffer =
673       GetAssignedInputAllocation(*buffers, paramscalar);
674   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
675   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
676   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
677   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
678   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
679 
680   // The mul node has a valid buffer assigned, doesn't share with input.
681   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
682   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
683 
684   // The add node can reuse the mul node's buffer.
685   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
686   EXPECT_EQ(add_buffer.index(), mul_buffer.index());
687 
688   // The sub node has a valid output buffer assigned.
689   GetAssignedOutputAllocation(*buffers, sub);
690 
691   // Check if the HLO instructions have the correct colors in the layout.
692   EXPECT_EQ(mul->shape().layout().memory_space(), 1);
693   EXPECT_EQ(add->shape().layout().memory_space(), 1);
694   EXPECT_EQ(sub->shape().layout().memory_space(), 0);
695   EXPECT_EQ(param0->shape().layout().memory_space(), 0);
696   EXPECT_EQ(param1->shape().layout().memory_space(), 0);
697 }
698 
TEST_F(BufferAssignmentTest,PresetAssignments)699 TEST_F(BufferAssignmentTest, PresetAssignments) {
700   // paramscalar ------- (mul) -- (add) -- (sub)
701   //                     /        /        /
702   // param0[100] -------/        /        /
703   //                            /        /
704   // param1[100] --------------/--------/
705   // Similar to BasicPartiallyColored, but the color is set in the layout.
706   // The output of the mul and the add have the color 1 and have preset
707   // assignments, and the other buffers have the color 0, which allows the mul
708   // and add to share buffers.
709   auto builder = HloComputation::Builder(TestName());
710   auto paramscalar =
711       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
712   auto broadcast = builder.AddInstruction(
713       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
714   auto param0 = builder.AddInstruction(
715       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
716   auto param1 = builder.AddInstruction(
717       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
718   Shape f32vec100_color1 =
719       ShapeUtil::MakeShapeWithLayout(F32, {100}, {0}, {}, 0, 1);
720   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
721       f32vec100_color1, HloOpcode::kMultiply, broadcast, param0));
722   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
723       f32vec100_color1, HloOpcode::kAdd, mul, param1));
724   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
725       f32vec100_, HloOpcode::kSubtract, add, param1));
726   auto module = CreateNewVerifiedModule();
727   module->AddEntryComputation(builder.Build());
728 
729   auto preset_assignments = absl::make_unique<PresetAssignments>();
730   preset_assignments->add_chunk({mul, {}}, {/*offset=*/100, /*size=*/400});
731   preset_assignments->add_chunk({add, {}}, {/*offset=*/550, /*size=*/400});
732   preset_assignments->assignment_information_for_space(/*memory_space=*/1)
733       ->size = 950;
734 
735   auto buffers = RunBufferAssignmentWithPresetAssignments(
736       module.get(), std::move(preset_assignments));
737 
738   // Distinct input buffers were assigned for parameters.
739   BufferAllocation paramscalar_buffer =
740       GetAssignedInputAllocation(*buffers, paramscalar);
741   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
742   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
743   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
744   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
745   EXPECT_EQ(paramscalar_buffer.color(), LogicalBuffer::Color(0));
746   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
747   EXPECT_EQ(param0_buffer.color(), LogicalBuffer::Color(0));
748 
749   // The mul and add use the same preset buffer. Ensure it has the correct color
750   // and offsets.
751   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
752   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
753   EXPECT_EQ(mul_buffer, add_buffer);
754   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
755   EXPECT_EQ(mul_buffer.color(), LogicalBuffer::Color(1));
756 
757   EXPECT_EQ(mul_buffer.assigned_buffers().size(), 2);
758   for (const auto& value_and_offsetsize : mul_buffer.assigned_buffers()) {
759     if (value_and_offsetsize.first->instruction() == mul) {
760       EXPECT_EQ(value_and_offsetsize.second.offset, 100);
761       EXPECT_EQ(value_and_offsetsize.second.size, 400);
762     } else {
763       EXPECT_EQ(value_and_offsetsize.first->instruction(), add);
764       EXPECT_EQ(value_and_offsetsize.second.offset, 550);
765       EXPECT_EQ(value_and_offsetsize.second.size, 400);
766     }
767   }
768 
769   // The sub node has a valid output buffer assigned.
770   GetAssignedOutputAllocation(*buffers, sub);
771 }
772 
TEST_F(BufferAssignmentTest,PresetAssignmentsWhile)773 TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
774   // Tests preset assignments when there is no 1-to-1 correspondence between
775   // HloValue and HloBuffer (i.e., a while loop).
776   auto module = CreateNewVerifiedModule();
777   Shape f32vec10_color1 =
778       ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, 0, 1);
779   Shape t_s32_f32v10_color1 =
780       ShapeUtil::MakeTupleShape({s32_, f32vec10_color1});
781 
782   auto cond_builder = HloComputation::Builder("WhileCond");
783   HloInstruction* cond_param = cond_builder.AddInstruction(
784       HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param"));
785   HloInstruction* cond_iter = cond_builder.AddInstruction(
786       HloInstruction::CreateGetTupleElement(s32_, cond_param, 0));
787   HloInstruction* cond_limit = cond_builder.AddInstruction(
788       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
789   cond_builder.AddInstruction(
790       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
791                                     cond_limit, ComparisonDirection::kLt));
792   HloComputation* cond_computation =
793       module->AddEmbeddedComputation(cond_builder.Build());
794 
795   auto body_builder = HloComputation::Builder("WhileBody");
796   HloInstruction* body_param = body_builder.AddInstruction(
797       HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param"));
798   HloInstruction* body_iter = body_builder.AddInstruction(
799       HloInstruction::CreateGetTupleElement(s32_, body_param, 0));
800   HloInstruction* body_data = body_builder.AddInstruction(
801       HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1));
802   HloInstruction* body_data_increment = body_builder.AddInstruction(
803       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
804           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
805   HloInstruction* body_data_next =
806       body_builder.AddInstruction(HloInstruction::CreateBinary(
807           f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment));
808   HloInstruction* body_iter_increment = body_builder.AddInstruction(
809       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
810   HloInstruction* body_iter_next =
811       body_builder.AddInstruction(HloInstruction::CreateBinary(
812           s32_, HloOpcode::kAdd, body_iter, body_iter_increment));
813   body_builder.AddInstruction(
814       HloInstruction::CreateTuple({body_iter_next, body_data_next}));
815   HloComputation* body_computation =
816       module->AddEmbeddedComputation(body_builder.Build());
817 
818   auto builder = HloComputation::Builder(TestName());
819   HloInstruction* iter = builder.AddInstruction(
820       HloInstruction::CreateParameter(0, s32_, "param_iter"));
821   HloInstruction* data = builder.AddInstruction(
822       HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
823   HloInstruction* negate = builder.AddInstruction(
824       HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data));
825   HloInstruction* tuple =
826       builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
827   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
828       t_s32_f32v10_color1, cond_computation, body_computation, tuple));
829   HloInstruction* while_data = builder.AddInstruction(
830       HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1));
831   builder.AddInstruction(HloInstruction::CreateBinary(
832       f32vec10_, HloOpcode::kAdd, while_data, data));
833   module->AddEntryComputation(builder.Build());
834 
835   // Set only one preset assignment for while data and its aliases.
836   auto preset_assignments = absl::make_unique<PresetAssignments>();
837   preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
838   preset_assignments->assignment_information_for_space(/*memory_space=*/1)
839       ->size = 140;
840 
841   auto buffers = RunBufferAssignmentWithPresetAssignments(
842       module.get(), std::move(preset_assignments));
843 
844   // All assigned buffers are aliased so they should have the same offset and
845   // size.
846   const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate);
847   EXPECT_EQ(data_buffer.assigned_buffers().size(), 5);
848   for (const auto& value_and_offsetsize : data_buffer.assigned_buffers()) {
849     EXPECT_EQ(value_and_offsetsize.second.offset, 100);
850     EXPECT_EQ(value_and_offsetsize.second.size, 40);
851     EXPECT_EQ(value_and_offsetsize.first->color(), LogicalBuffer::Color(1));
852   }
853 }
854 
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
  // This is similar to the Basic test, with the difference that (sub) is
  // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
  // (add)'s output.
  //
  //                      +-----------------+
  //                      |                 |
  // paramscalar ------ (mul) -- (add) -- (sub)
  //                   /        /
  // param0[100] -----/        /
  //                          /
  // param1[100] ------------/
  //
  // Keep diagram lines free of a trailing backslash: a '\' at the end of a
  // '//' comment splices the next source line into the comment (-Wcomment).
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node had a buffer allocated.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);

  // sub reads mul after add runs, so mul is still live at add and add can't
  // reuse mul's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // Log size information for inspection.
  const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
  int64 size0 = ValidateBuffers(level0, *buffers);
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() << " instructions; "
            << "total buffer size " << size0;
}
909 
TEST_F(BufferAssignmentTest, TrivialMap) {
  // The only operation is an elementwise x+1 map:
  //
  // param0[100x10] ---> (map x+1)
  auto module = CreateNewVerifiedModule();
  // Embedded computation applied to every element by the map.
  auto plus_one =
      module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
  auto plus_one_root = plus_one->root_instruction();

  // Build the entry computation and sanity-check the instruction counts.
  auto entry_builder = HloComputation::Builder(TestName());
  auto input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  auto map = entry_builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10_, {input}, plus_one));
  module->AddEntryComputation(entry_builder.Build());

  const std::vector<const HloInstruction*> outer_instrs = GetInstructions(map);
  EXPECT_EQ(2, outer_instrs.size()) << "Invalid main kernel size";
  const std::vector<const HloInstruction*> inner_instrs =
      GetInstructions(plus_one_root);
  EXPECT_EQ(3, inner_instrs.size()) << "Invalid nested add+1 size";

  // Run buffer assignment and collect the sizes reported by ValidateBuffers.
  auto buffers = RunBufferAssignment(module.get());
  const int64 outer_size = ValidateBuffers(outer_instrs, *buffers);
  const int64 inner_size = ValidateBuffers(inner_instrs, *buffers);

  // The entry computation and the embedded computation must not share any
  // buffer.
  EXPECT_TRUE(BuffersDistinct(outer_instrs, inner_instrs, *buffers))
      << "Reuse between main kernel and embedded mapping.";

  // The parameter's input allocation differs from the map's output allocation.
  BufferAllocation input_alloc = GetAssignedInputAllocation(*buffers, input);
  BufferAllocation map_alloc = GetAssignedOutputAllocation(*buffers, map);
  EXPECT_NE(input_alloc.index(), map_alloc.index());

  // The embedded root is an add (f32 param + constant), and it does not share
  // the map's output allocation.
  EXPECT_EQ(HloOpcode::kAdd, plus_one_root->opcode());
  const BufferAllocation& root_alloc =
      GetTopLevelAllocation(*buffers, plus_one_root);
  EXPECT_NE(root_alloc.index(), map_alloc.index());

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << outer_instrs.size() + inner_instrs.size()
            << " instructions; " << "total buffer size "
            << outer_size + inner_size;
}
964 
TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  // Make sure that the input buffer of a reduce cannot be reused for its
  // output, even though the reduce is the input's only user. (Reuse is not safe
  // in the general case: the reduce changes the shape, and some out-of-order
  // reductions could overwrite an element before it is read.)
  //
  // param0[100x10] --- (exp1) --- (exp2) --- (reduce x+y, dim 0) --- (exp3)
  auto module = CreateNewVerifiedModule();
  auto reduce_computation =
      module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));

  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      /*shape=*/f32vec10_,
      /*operand=*/exp2,
      /*init_value=*/const0,
      /*dimensions_to_reduce=*/{0}, reduce_computation));
  auto exp3 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));

  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
  ValidateBuffers(instrs, *buffers);

  const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1);
  const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2);
  const BufferAllocation& reduce_buffer =
      GetTopLevelAllocation(*buffers, reduce);

  // The buffer of exp1 is trivially reusable for exp2 - this is just for sanity
  // checking.
  EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index());

  // The buffer of exp2 cannot be used for reduce, even though it's the only
  // operand.
  EXPECT_NE(exp2_buffer.index(), reduce_buffer.index());
}
1011 
TEST_F(BufferAssignmentTest,ExampleWhile)1012 TEST_F(BufferAssignmentTest, ExampleWhile) {
1013   // This tests a While loop example from the ir_semantics document.
1014   //
1015   // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
1016   // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
1017   //
1018   // const3[s32] -------\
1019   // const4[f32[4]] --- tuple --- while[condition, body]
1020   //
1021   // Builds the nested condition and body.
1022   auto module = CreateNewVerifiedModule();
1023   auto condition_computation =
1024       module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
1025   auto body_computation =
1026       module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));
1027 
1028   // Creates the main kernel and verifies instruction counts.
1029   auto builder = HloComputation::Builder(TestName());
1030   auto const3 = builder.AddInstruction(
1031       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
1032   auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
1033       LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
1034   auto tuple =
1035       builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
1036   auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
1037       t_s32_f32v4_, condition_computation, body_computation, tuple));
1038   module->AddEntryComputation(builder.Build());
1039 
1040   const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
1041   EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
1042   const std::vector<const HloInstruction*> levelc =
1043       GetInstructions(condition_computation->root_instruction());
1044   EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
1045   const std::vector<const HloInstruction*> levelb =
1046       GetInstructions(body_computation->root_instruction());
1047   EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";
1048 
1049   // Assigns buffers and fetches sizes.
1050   auto buffers = RunBufferAssignment(module.get());
1051   int64 size0 = ValidateBuffers(level0, *buffers);
1052   int64 sizec = ValidateBuffers(levelc, *buffers);
1053   int64 sizeb = ValidateBuffers(levelb, *buffers);
1054 
1055   // BufferAssignment will assign a single allocation for the following
1056   // instructions: while, while.cond.param, while.body.param, while.body.result.
1057   EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
1058       << "Should be reuse between main kernel and embedded condition.";
1059   EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
1060       << "Should be reuse between embedded condition and body.";
1061   // Expect buffer reuse between main kernel and body computation.
1062   EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
1063       << "Should be reuse between main kernel and embedded body.";
1064 
1065   // The final computation node of the while body is a tuple of s32 and
1066   // f32[4] adds.
1067   HloInstruction* body_root = body_computation->root_instruction();
1068   EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());
1069 
1070   // Check that buffer for each subshape of 'while_op' shares allocation with
1071   // corresponding buffer from while body computation at same index.
1072   ShapeUtil::ForEachSubshape(
1073       while_op->shape(),
1074       [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
1075                                             const ShapeIndex& index) {
1076         auto while_op_allocation = GetAllocation(*buffers, while_op, index);
1077         auto body_root_allocation = GetAllocation(*buffers, body_root, index);
1078         EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index());
1079       });
1080 
1081   // Log size information for inspection.
1082   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
1083             << " for " << level0.size() + levelc.size() + levelb.size()
1084             << " instructions; total buffer size " << size0 + sizec + sizeb;
1085 }
1086 
TEST_F(BufferAssignmentTest, ExampleConditional) {
  // Entry computation: a conditional on the constant predicate `false` that
  // routes const1 into a scalar f32 Ceil branch (true) and const2 into a scalar
  // f32 Floor branch (false).
  auto module = CreateNewVerifiedModule();
  auto ceil_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
  auto floor_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));

  auto builder = HloComputation::Builder(TestName());
  auto predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  auto true_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  auto false_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
  auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
      r0f32_, predicate, true_operand, ceil_computation, false_operand,
      floor_computation));
  module->AddEntryComputation(builder.Build());

  auto* ceil_root = ceil_computation->root_instruction();
  auto* floor_root = floor_computation->root_instruction();

  // Sanity-check the size of each instruction set.
  const std::vector<const HloInstruction*> entry_instrs =
      GetInstructions(conditional);
  const std::vector<const HloInstruction*> ceil_instrs =
      GetInstructions(ceil_root);
  const std::vector<const HloInstruction*> floor_instrs =
      GetInstructions(floor_root);
  EXPECT_EQ(4, entry_instrs.size());
  EXPECT_EQ(2, ceil_instrs.size());
  EXPECT_EQ(2, floor_instrs.size());

  auto buffers = RunBufferAssignment(module.get());
  ValidateBuffers(entry_instrs, *buffers);
  ValidateBuffers(ceil_instrs, *buffers);
  ValidateBuffers(floor_instrs, *buffers);

  // Buffers must be shared between the entry computation and each branch, and
  // between the two branches.
  EXPECT_FALSE(BuffersDistinct(entry_instrs, ceil_instrs, *buffers))
      << "Should be reuse between conditional and true computation.";
  EXPECT_FALSE(BuffersDistinct(entry_instrs, floor_instrs, *buffers))
      << "Should be reuse between conditional and false computation.";
  EXPECT_FALSE(BuffersDistinct(ceil_instrs, floor_instrs, *buffers))
      << "Should be reuse between true and false computations.";

  // The conditional and both branch roots get allocations of the same size.
  const BufferAllocation& conditional_alloc =
      GetTopLevelAllocation(*buffers, conditional);
  const BufferAllocation& ceil_alloc =
      GetTopLevelAllocation(*buffers, ceil_root);
  const BufferAllocation& floor_alloc =
      GetTopLevelAllocation(*buffers, floor_root);
  EXPECT_EQ(conditional_alloc.size(), ceil_alloc.size());
  EXPECT_EQ(conditional_alloc.size(), floor_alloc.size());
}
1136 
TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
  // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
  //
  // Every op after the parameter is expected to reuse exp1's buffer.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // ASSERT rather than EXPECT: the lookups below need exp1 to have a
  // top-level allocation, so stop the test here if it doesn't.
  ASSERT_TRUE(assignment->HasTopLevelAllocation(exp1));
  // tanh, exp2 and neg all reuse exp1's buffer.
  const BufferAllocation& buffer_for_exp1 =
      GetTopLevelAllocation(*assignment, exp1);
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh));
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2));
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg));
}
1162 
TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
  // This computation is a chain of operations which decreases in buffer size
  // (via slice) then increases in size (via broadcast):
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast, even though negate is not
  // an operand of broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // ASSERT rather than EXPECT: the comparisons below look up broadcast's
  // top-level allocation, so stop the test here if there is none.
  ASSERT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  const BufferAllocation& buffer_for_bcast =
      GetTopLevelAllocation(*assignment, broadcast);
  // negate and broadcast should share a buffer.
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}
1192 
TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
  // Same chain as ReuseNonOperandBuffer, except that negate also feeds the
  // output tuple. That keeps negate live until the end of the computation, so
  // its buffer cannot be reused:
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
  //                  \-----------------------------------/
  //
  // As a result, negate, slice and broadcast must all get distinct buffers.
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  auto head = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negated, {0}, {10}, {1}));
  auto expanded = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, head, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({negated, expanded}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  const BufferAllocation& negated_alloc =
      GetTopLevelAllocation(*assignment, negated);
  const BufferAllocation& head_alloc = GetTopLevelAllocation(*assignment, head);
  const BufferAllocation& expanded_alloc =
      GetTopLevelAllocation(*assignment, expanded);

  // No two of the three instructions may share an allocation.
  EXPECT_NE(expanded_alloc, negated_alloc);
  EXPECT_NE(expanded_alloc, head_alloc);
  EXPECT_NE(negated_alloc, head_alloc);
}
1225 
TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
  // Like ReuseNonOperandBuffer, but negate is first wrapped in a tuple that is
  // itself part of the final output. Through that aliasing, negate's buffer
  // stays live to the end of the computation and cannot be reused:
  //
  // param -> (negate) -> (tuple) -> (gte) -> (slice) -> (bcast) -> (tuple)
  //                         \-------------------------------------------/
  //
  // As a result, negate, slice and broadcast must all get distinct buffers.
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  auto wrapped = builder.AddInstruction(HloInstruction::CreateTuple({negated}));
  auto unwrapped = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec100_, wrapped, 0));
  auto head = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, unwrapped, {0}, {10}, {1}));
  auto expanded = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, head, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({wrapped, expanded}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  const BufferAllocation& negated_alloc =
      GetTopLevelAllocation(*assignment, negated);
  const BufferAllocation& head_alloc = GetTopLevelAllocation(*assignment, head);
  const BufferAllocation& expanded_alloc =
      GetTopLevelAllocation(*assignment, expanded);

  // No two of the three instructions may share an allocation.
  EXPECT_NE(expanded_alloc, negated_alloc);
  EXPECT_NE(expanded_alloc, head_alloc);
  EXPECT_NE(negated_alloc, head_alloc);
}
1262 
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
  // Variant of ReuseNonOperandBuffer in which the broadcast (the computation's
  // output) is smaller than the negate. Output buffers must be sized exactly
  // for their value, so the broadcast may not take over the negate's larger
  // buffer.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The broadcast must not share a buffer with either negate or slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // 100 elements.
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  // 10 elements.
  HloInstruction* head = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negated, {0}, {10}, {1}));
  // 10 x 4 = 40 elements: smaller than the negate's 100.
  const Shape small_output_shape = ShapeUtil::MakeShape(F32, {10, 4});
  HloInstruction* expanded = builder.AddInstruction(
      HloInstruction::CreateBroadcast(small_output_shape, head, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  const BufferAllocation& output_alloc =
      GetTopLevelAllocation(*assignment, expanded);
  EXPECT_NE(output_alloc, GetTopLevelAllocation(*assignment, negated));
  EXPECT_NE(output_alloc, GetTopLevelAllocation(*assignment, head));
}
1295 
TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
  // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
  // output is exactly the same size as the negate (rather than being
  // smaller). This enables reuse of negate's buffer by the broadcast because
  // the output buffer will be sized exactly to its value.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate *should* share a buffer with broadcast; the slice should not.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 10 x 10 = 100 elements, the same size as negate.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer. ASSERT (not EXPECT) so that a
  // missing allocation stops the test here instead of continuing into the
  // lookups below, which presumably abort when no allocation exists.
  ASSERT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  const auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}
1329 
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
  // Like DoNotReuseOversizedOutputBuffer, but the undersized broadcast reaches
  // the computation output as an element of the root tuple rather than being
  // the root itself. Exact sizing applies to every buffer aliased by the
  // output, tuple elements included, so reuse of the negate's buffer by the
  // broadcast is still forbidden.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
  //
  // The broadcast must not share a buffer with either negate or slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // 100 elements.
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  // 10 elements.
  HloInstruction* head = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negated, {0}, {10}, {1}));
  // 10 x 4 = 40 elements: smaller than the negate's 100.
  const Shape small_output_shape = ShapeUtil::MakeShape(F32, {10, 4});
  HloInstruction* expanded = builder.AddInstruction(
      HloInstruction::CreateBroadcast(small_output_shape, head, {0}));
  builder.AddInstruction(HloInstruction::CreateTuple({expanded}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  const BufferAllocation& output_alloc =
      GetTopLevelAllocation(*assignment, expanded);
  EXPECT_NE(output_alloc, GetTopLevelAllocation(*assignment, negated));
  EXPECT_NE(output_alloc, GetTopLevelAllocation(*assignment, head));
}
1365 
TEST_F(BufferAssignmentTest,EmbeddedComputationBuffers)1366 TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
1367   // Verify that buffers for embedded computations are properly marked as
1368   // thread-local and that embedded parameters are not marked as
1369   // is_entry_computation_parameter.
1370   auto module = CreateNewVerifiedModule();
1371   auto vec_shape = ShapeUtil::MakeShape(F32, {42});
1372   auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1373 
1374   // Create a scalar computation to use in a map.
1375   auto map_builder = HloComputation::Builder(TestName() + "_map");
1376   auto map_param = map_builder.AddInstruction(
1377       HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
1378   auto map_root = map_builder.AddInstruction(
1379       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
1380   auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
1381 
1382   // Create a vector computation to use in a kCall.
1383   auto call_builder = HloComputation::Builder(TestName() + "_call");
1384   auto call_param = call_builder.AddInstruction(
1385       HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
1386   auto call_root = call_builder.AddInstruction(
1387       HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
1388   auto call_computation = module->AddEmbeddedComputation(call_builder.Build());
1389 
1390   // Create entry computation which kCalls call_computation and then calls map
1391   // with map_computation on the result.
1392   auto builder = HloComputation::Builder(TestName());
1393   auto param = builder.AddInstruction(
1394       HloInstruction::CreateParameter(0, vec_shape, "param"));
1395   auto call = builder.AddInstruction(
1396       HloInstruction::CreateCall(vec_shape, {param}, call_computation));
1397   auto map = builder.AddInstruction(
1398       HloInstruction::CreateMap(vec_shape, {call}, map_computation));
1399   module->AddEntryComputation(builder.Build());
1400 
1401   auto assignment = RunBufferAssignment(module.get());
1402 
1403   // Allocations for the map computation should be thread-local and not
1404   // live-out.
1405   auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
1406   EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
1407   EXPECT_FALSE(map_param_alloc.maybe_live_out());
1408   EXPECT_TRUE(map_param_alloc.is_thread_local());
1409 
1410   auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
1411   EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
1412   EXPECT_FALSE(map_root_alloc.maybe_live_out());
1413   EXPECT_TRUE(map_root_alloc.is_thread_local());
1414 
1415   // Allocations for the call computation should not be thread-local.
1416   auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
1417   EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter());
1418   EXPECT_FALSE(call_param_alloc.maybe_live_out());
1419   EXPECT_FALSE(call_param_alloc.is_thread_local());
1420 
1421   auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
1422   EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
1423   EXPECT_FALSE(call_root_alloc.is_thread_local());
1424 
1425   // Entry computation allocations can be marked liveout and
1426   // is_entry_computation_parameter.
1427   auto& param_alloc = GetTopLevelAllocation(*assignment, param);
1428   EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
1429   EXPECT_FALSE(param_alloc.maybe_live_out());
1430   EXPECT_FALSE(param_alloc.is_thread_local());
1431 
1432   auto& map_alloc = GetTopLevelAllocation(*assignment, map);
1433   EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
1434   EXPECT_TRUE(map_alloc.maybe_live_out());
1435   EXPECT_FALSE(map_alloc.is_thread_local());
1436 }
1437 
TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
  // Test a computation whose root is a tuple parameter, i.e. the parameter is
  // returned unchanged as the computation output.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(F32, {}),
                                 ShapeUtil::MakeShape(S32, {42})}),
      "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be four allocations: one for vector of pointers, and one for
  // each tuple element.
  EXPECT_EQ(4, assignment->Allocations().size());

  // Verify each buffer allocation (top level and every element) is marked as
  // entry computation parameter 0 and is liveout.
  ShapeUtil::ForEachSubshape(
      tuple_param->shape(),
      [this, &assignment, tuple_param](const Shape& /*subshape*/,
                                       const ShapeIndex& index) {
        // Bind by reference: only accessors are called, so copying the
        // BufferAllocation (and its assigned-buffer map) is unnecessary.
        const BufferAllocation& allocation =
            GetAllocation(*assignment, tuple_param, index);
        EXPECT_TRUE(allocation.is_entry_computation_parameter());
        EXPECT_EQ(0, allocation.parameter_number());
        EXPECT_TRUE(allocation.maybe_live_out());
      });
}
1468 
TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
  // Test a computation which returns a GetTupleElement of a nested tuple
  // parameter.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
           ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
                                      ShapeUtil::MakeShape(S32, {101})})}),
      "param0"));
  auto tuple_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Only some of the elements of the input param are liveout.
  EXPECT_FALSE(
      GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
  // Tuple element at index={1} is live out because GetTupleElement({1})
  // forwards a pointer to this allocation (instead of defining its own buffer).
  EXPECT_TRUE(
      GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
                  .maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
                  .maybe_live_out());

  // The GetTupleElement output is liveout.
  EXPECT_TRUE(
      GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());

  // Verify that the GetTupleElement allocations of its elements match the
  // corresponding tuple parameter allocations because they alias.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
            GetAllocation(*assignment, tuple_element, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
            GetAllocation(*assignment, tuple_element, /*index=*/{1}));

  // GetTupleElement forwards a pointer to its underlying buffer, so verify
  // that it has the same allocation as the corresponding parameter element.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
            GetTopLevelAllocation(*assignment, tuple_element));
}
1516 
1517 // TODO(b/32248867): Enable when buffer assignment gives allocations to
1518 // constants.
TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
  // A tuple-shaped constant that is forwarded as the computation's output must
  // be handled by buffer assignment. Three allocations are expected
  // (presumably the tuple itself plus one per scalar element).
  Literal zero = LiteralUtil::CreateR0<int64>(0);
  Literal one = LiteralUtil::CreateR0<int64>(1);

  HloComputation::Builder builder(TestName());
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MakeTuple({&zero, &one})));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(3, assignment->Allocations().size());
}
1534 
TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
  // The computation directly returns a tuple-shaped custom call. Three
  // allocations are expected, and every piece of the result (the tuple and
  // both of its elements) must be marked live-out.
  const Shape result_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(S32, {101})});
  HloComputation::Builder builder(TestName());
  HloInstruction* custom_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          result_shape, /*operands=*/{},
          /*custom_call_target=*/"foo_function"));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(3, assignment->Allocations().size());
  // Visits index {} then {0} then {1}.
  ShapeUtil::ForEachSubshape(
      custom_call->shape(),
      [this, &assignment, custom_call](const Shape& /*subshape*/,
                                       const ShapeIndex& index) {
        EXPECT_TRUE(
            GetAllocation(*assignment, custom_call, index).maybe_live_out());
      });
}
1554 
TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
  // The entry computation returns the tuple produced by a kCall whose callee
  // simply wraps its parameter in a tuple.
  auto module = CreateNewVerifiedModule();
  const auto element_shape = f32vec4_;
  const auto result_shape = ShapeUtil::MakeTupleShape({element_shape});

  HloComputation::Builder callee_builder(TestName() + "_sub");
  HloInstruction* callee_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, element_shape, "sub_param"));
  HloInstruction* callee_root = callee_builder.AddInstruction(
      HloInstruction::CreateTuple({callee_param}));
  auto* callee = module->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, element_shape, "param"));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(result_shape, {entry_param}, callee));
  module->AddEntryComputation(entry_builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(2, assignment->Allocations().size());
  // The call's result (the tuple and its element) is colocated with the
  // callee's buffers.
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
            GetAllocation(*assignment, callee_root, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
            GetAllocation(*assignment, callee_param, /*index=*/{}));

  // The entry parameter is aliased with the callee parameter it feeds, but
  // not with the callee's result tuple.
  EXPECT_NE(GetTopLevelAllocation(*assignment, entry_param),
            GetTopLevelAllocation(*assignment, callee_root));
  EXPECT_EQ(GetTopLevelAllocation(*assignment, entry_param),
            GetTopLevelAllocation(*assignment, callee_param));
}
1591 
TEST_F(BufferAssignmentTest,TupleChainedCallAsOutput)1592 TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
1593   // Test a chain of calls with tuple output. The chain looks like:
1594   // A: call(B, tuple(param))
1595   // B: call(C, param)
1596   // C: call(D, param)
1597   // D: param
1598   auto module = CreateNewVerifiedModule();
1599   auto elem_shape = f32vec4_;
1600   auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1601 
1602   auto d_builder = HloComputation::Builder(TestName() + "_d");
1603   auto d_param = d_builder.AddInstruction(
1604       HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
1605   auto d_computation = d_builder.Build();
1606 
1607   auto c_builder = HloComputation::Builder(TestName() + "_c");
1608   auto c_param = c_builder.AddInstruction(
1609       HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
1610   auto c_call = c_builder.AddInstruction(
1611       HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
1612   auto c_computation = c_builder.Build();
1613 
1614   auto b_builder = HloComputation::Builder(TestName() + "_b");
1615   auto b_param = b_builder.AddInstruction(
1616       HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
1617   auto b_call = b_builder.AddInstruction(
1618       HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
1619   auto b_computation = b_builder.Build();
1620 
1621   auto a_builder = HloComputation::Builder(TestName());
1622   auto a_param = a_builder.AddInstruction(
1623       HloInstruction::CreateParameter(0, elem_shape, "param"));
1624   auto a_tuple =
1625       a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
1626   auto a_call = a_builder.AddInstruction(
1627       HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
1628   auto a_computation = a_builder.Build();
1629 
1630   // Add the computations in an order that doesn't match the dependency
1631   // post-order, to shake out more possible bugs.
1632   module->AddEmbeddedComputation(std::move(d_computation));
1633   module->AddEmbeddedComputation(std::move(c_computation));
1634   module->AddEntryComputation(std::move(a_computation));
1635   module->AddEmbeddedComputation(std::move(b_computation));
1636 
1637   auto assignment = RunBufferAssignment(module.get());
1638 
1639   // Buffers for call are colocated with the sub-computations.
1640   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
1641             GetAllocation(*assignment, b_call, /*index=*/{}));
1642   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
1643             GetAllocation(*assignment, c_call, /*index=*/{}));
1644   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
1645             GetAllocation(*assignment, d_param, /*index=*/{}));
1646   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
1647             GetAllocation(*assignment, b_call, /*index=*/{0}));
1648   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
1649             GetAllocation(*assignment, c_call, /*index=*/{0}));
1650   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
1651             GetAllocation(*assignment, d_param, /*index=*/{0}));
1652 
1653   EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
1654   EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
1655   EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));
1656 
1657   EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}),
1658             GetAllocation(*assignment, c_param, /*index=*/{0}));
1659   EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}),
1660             GetAllocation(*assignment, d_param, /*index=*/{0}));
1661 }
1662 
TEST_F(BufferAssignmentTest, BitcastAsOutput) {
  // The computation's result is a bitcast of its parameter. Exactly one
  // allocation is expected, shared by the parameter and the bitcast.
  HloComputation::Builder builder(TestName());
  const Shape vec42 = ShapeUtil::MakeShape(F32, {42});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec42, "param"));
  HloInstruction* reinterpreted = builder.AddInstruction(
      HloInstruction::CreateBitcast(input->shape(), input));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(1, assignment->Allocations().size());
  EXPECT_EQ(GetTopLevelAllocation(*assignment, input),
            GetTopLevelAllocation(*assignment, reinterpreted));
}
1680 
TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
  // Test a computation with an output that has an ambiguous points-to set.
  // This is constructed using a tuple-select among tuple-shaped parameters.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});

  auto tuple_param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  auto tuple_param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  // Parameter number 2 is the scalar predicate; its name matches its number.
  auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto select = builder.AddInstruction(
      HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect,
                                    pred_param, tuple_param0, tuple_param1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Select shallow copies one of its operands so it defines its own top-level
  // buffer and receives its own allocation. Bind by reference; no copy of the
  // allocation is needed for these read-only checks.
  const BufferAllocation& select_alloc =
      GetTopLevelAllocation(*assignment, select);
  EXPECT_EQ(1, select_alloc.assigned_buffers().size());
  EXPECT_EQ(select,
            select_alloc.assigned_buffers().begin()->first->instruction());

  // The buffer for the tuple element of the select is forwarded from one of
  // its operands, which cannot be determined statically. Therefore its slices
  // should include the slices of both of the elements in the parameters.
  auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
  EXPECT_EQ(2, element_slices.size());
  EXPECT_THAT(element_slices,
              UnorderedElementsAre(
                  assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
                      .ConsumeValueOrDie(),
                  assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
                      .ConsumeValueOrDie()));
}
1721 
1722 // TODO(b/34669761): Remove this test when buffers are allowed to share
1723 // allocations.
TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
  // Test a computation that wraps a parameter in a tuple, extracts it again
  // with get-tuple-element, and returns a copy of it:
  //
  //   param -> (tuple) -> (gte) -> (copy)
  //
  // The copy must not reuse the tuple's buffer.
  auto builder = HloComputation::Builder(TestName());
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param0"));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({param}));
  auto tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
  auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape, HloOpcode::kCopy, tuple_element));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be no buffer reuse. The copy should not reuse the tuple
  // buffer.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
            GetTopLevelAllocation(*assignment, copy));
}
1746 
TEST_F(BufferAssignmentTest,OneTempAllocation)1747 TEST_F(BufferAssignmentTest, OneTempAllocation) {
1748   // Test a computation that requires multiple temp buffers, and ensure they
1749   // are combined into a single allocation.
1750   auto builder = HloComputation::Builder(TestName());
1751   Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
1752   Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
1753   Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
1754   Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
1755   Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});
1756 
1757   // There should be separate temp buffers for dot_ab and dot_bc.
1758   auto param_a = builder.AddInstruction(
1759       HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
1760   auto param_b = builder.AddInstruction(
1761       HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
1762   auto param_c = builder.AddInstruction(
1763       HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
1764   DotDimensionNumbers dot_dnums;
1765   dot_dnums.add_lhs_contracting_dimensions(1);
1766   dot_dnums.add_rhs_contracting_dimensions(0);
1767   PrecisionConfig precision_config;
1768   precision_config.mutable_operand_precision()->Resize(
1769       2, PrecisionConfig::DEFAULT);
1770   auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
1771       shape_2x4, param_a, param_b, dot_dnums, precision_config));
1772   auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
1773       shape_3x4, param_b, param_c, dot_dnums, precision_config));
1774   builder.AddInstruction(
1775       HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
1776 
1777   // Run buffer assignment with alignment=1.
1778   auto module = CreateNewVerifiedModule();
1779   module->AddEntryComputation(builder.Build());
1780   auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
1781 
1782   // There are 5 allocations: 3 parameters, 1 output, and 1 temp.
1783   EXPECT_EQ(5, assignment->Allocations().size());
1784 
1785   // Ensure the temp buffers for dot_ab and dot_bc share a single allocation,
1786   // and each occupies different slices of that allocation.
1787   BufferAllocation::Slice slice_ab =
1788       assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
1789   BufferAllocation::Slice slice_bc =
1790       assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
1791   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1792   EXPECT_NE(slice_ab, slice_bc);
1793   EXPECT_EQ(32, slice_ab.size());
1794   EXPECT_EQ(48, slice_bc.size());
1795   EXPECT_EQ(80, slice_ab.allocation()->size());
1796   EXPECT_EQ(80, slice_bc.allocation()->size());
1797 
1798   // Re-run buffer assignment with alignment=64.
1799   assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
1800   EXPECT_EQ(5, assignment->Allocations().size());
1801   slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
1802   slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
1803   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1804   EXPECT_NE(slice_ab, slice_bc);
1805   EXPECT_EQ(32, slice_ab.size());
1806   EXPECT_EQ(48, slice_bc.size());
1807   // Ensure the offsets and allocation size account for the alignment, without
1808   // assuming which buffer gets assigned first.
1809   if (slice_ab.offset() == 0) {
1810     EXPECT_EQ(64, slice_bc.offset());
1811     EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
1812     EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
1813   } else {
1814     EXPECT_EQ(64, slice_ab.offset());
1815     EXPECT_EQ(0, slice_bc.offset());
1816     EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
1817     EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
1818   }
1819 }
1820 
TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
  // paramscalar -(bc)- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  //
  // The allocation holding mul should report exactly one value at its peak
  // memory point: the final subtraction.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* splat = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, scalar, {}));
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* product = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, splat, lhs));
  HloInstruction* total = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, product, rhs));
  HloInstruction* difference = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, total,
                                   rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  const BufferAllocation& product_alloc =
      GetTopLevelAllocation(*buffers, product);
  const std::vector<const HloValue*>& peak_values =
      product_alloc.PeakMemoryLogicalBuffers();
  ASSERT_EQ(peak_values.size(), 1);
  EXPECT_EQ(peak_values[0]->instruction(), difference);
}
1853 
TEST_F(BufferAssignmentTest, PeakBuffers) {
  // Instruction sequence under test:
  //
  //   %param  = ...
  //   %log    = log(%param)
  //   %rev    = reverse(%log)
  //   %neg    = neg(%param)
  //   %concat = concat(%rev, %neg)
  //   ROOT %root = slice(%concat)
  //
  // Inside the temporary allocation, memory use peaks at %concat itself, where
  // {%rev, %neg, %concat} are all live at once.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  HloInstruction* logarithm = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
  HloInstruction* reversed = builder.AddInstruction(
      HloInstruction::CreateReverse(f32vec100_, logarithm, {0}));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
  HloInstruction* concatenated = builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, {reversed, negated}, 0));
  // The root is deliberately tiny so that no interior value can share its
  // buffer.
  const Shape root_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* sliced = builder.AddInstruction(
      HloInstruction::CreateSlice(root_shape, concatenated, {0}, {1}, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignmentWithInstructionSequence(
      module.get(), {param, logarithm, reversed, negated, concatenated, sliced});

  // The four interior values (log, reverse, negate, concat) share one
  // preallocated temporary allocation.
  const BufferAllocation& temp =
      GetTopLevelAllocation(*assignment, concatenated);
  EXPECT_FALSE(temp.IsInputOrOutput());
  EXPECT_TRUE(temp.IsPreallocatedTempBuffer());
  ASSERT_EQ(temp.assigned_buffers().size(), 4);

  // At the peak, exactly %concat and its two operands are live.
  const std::vector<const HloValue*>& peak = temp.PeakMemoryLogicalBuffers();
  ASSERT_EQ(peak.size(), 3);
  std::vector<const HloInstruction*> defining_instructions;
  defining_instructions.reserve(peak.size());
  for (const HloValue* value : peak) {
    defining_instructions.push_back(value->instruction());
  }
  EXPECT_THAT(defining_instructions,
              UnorderedElementsAre(reversed, negated, concatenated));
}
1906 
TEST_F(BufferAssignmentTest, InPlaceBuffer) {
  // Two chained dynamic-update-slices start from element {1} of the entry
  // parameter tuple. Both updates are expected to happen in place, i.e. they
  // share the allocation of the get-tuple-element they update.
  //
  // The ROOT tuple's declared shape lists exactly the shapes of its two
  // operands (add.5 and dynamic-update-slice.9).
  const char* hlo_text = R"(
HloModule Module

ENTRY main {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
  // The buffer being updated is a get-tuple-element of the entry parameter,
  // not the parameter instruction itself.
  HloInstruction* gte =
      m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
  HloInstruction* dus1 =
      m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
  HloInstruction* dus2 =
      m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");

  auto buffers = RunBufferAssignment(m.get());

  const BufferAllocation& gte_alloc = GetTopLevelAllocation(*buffers, gte);
  EXPECT_EQ(gte_alloc, GetTopLevelAllocation(*buffers, dus1));
  EXPECT_EQ(gte_alloc, GetTopLevelAllocation(*buffers, dus2));
}
1946 
TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
  const char* hlo_text = R"(
HloModule Module

True {
  ROOT x.0.1 = f32[] parameter(0)
}

False {
  x.0.0 = f32[] parameter(0)
  ROOT copy.1 = f32[] copy(x.0.0)
}

ENTRY main {
  pred.1.0 = pred[] parameter(0)
  constant.1.1 = f32[] constant(56)
  copy.2 = f32[] copy(constant.1.1)
  constant.1.2 = f32[] constant(12)
  ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2),
      true_computation=True, false_computation=False
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* constant_1 =
      m->entry_computation()->GetInstructionWithName("constant.1.1");
  HloInstruction* constant_2 =
      m->entry_computation()->GetInstructionWithName("constant.1.2");

  auto buffers = RunBufferAssignment(m.get());

  // Each constant must live in a constant allocation, and that allocation must
  // not be shared with a copy or with the conditional's result.
  for (HloInstruction* constant : {constant_1, constant_2}) {
    const BufferAllocation& allocation =
        GetTopLevelAllocation(*buffers, constant);
    EXPECT_TRUE(allocation.is_constant());
    for (const auto& assigned : allocation.assigned_buffers()) {
      const HloOpcode opcode = assigned.first->instruction()->opcode();
      EXPECT_NE(opcode, HloOpcode::kCopy);
      EXPECT_NE(opcode, HloOpcode::kConditional);
    }
  }
}
2004 
2005 class WhileBufferAssignmentTest : public HloTestBase {
2006  protected:
BuildWhileConditionComputation(const string & name)2007   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
2008       const string& name) {
2009     auto builder = HloComputation::Builder(name);
2010     builder.AddInstruction(
2011         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2012     auto zero = builder.AddInstruction(
2013         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
2014     auto ten = builder.AddInstruction(
2015         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
2016     builder.AddInstruction(HloInstruction::CreateCompare(
2017         ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
2018     return builder.Build();
2019   }
2020 
BuildWhileBodyComputation(const string & name)2021   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
2022       const string& name) {
2023     auto builder = HloComputation::Builder(name);
2024     auto loop_state = builder.AddInstruction(
2025         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2026     auto input = builder.AddInstruction(
2027         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
2028     auto weights = builder.AddInstruction(
2029         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
2030     auto output = builder.AddInstruction(HloInstruction::CreateBinary(
2031         data_shape_, HloOpcode::kMultiply, input, weights));
2032     builder.AddInstruction(
2033         HloInstruction::CreateTuple({input, weights, output}));
2034     return builder.Build();
2035   }
2036 
RunBufferAssignment(HloModule * module,int64 alignment=1)2037   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
2038                                                         int64 alignment = 1) {
2039     HloSchedule schedule =
2040         ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
2041     return BufferAssigner::Run(
2042                module, absl::make_unique<SequentialHloOrdering>(schedule),
2043                ByteSizeOf,
2044                [alignment](LogicalBuffer::Color) { return alignment; },
2045                /*allocate_buffers_for_constants=*/true)
2046         .ConsumeValueOrDie();
2047   }
2048 
ByteSizeOf(const BufferValue & buffer)2049   static int64 ByteSizeOf(const BufferValue& buffer) {
2050     return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
2051   }
2052 
2053   Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
2054   Shape loop_state_shape_ =
2055       ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
2056 };
2057 
RunCopyInsertion(HloModule * module)2058 static void RunCopyInsertion(HloModule* module) {
2059   CopyInsertion copy_insertion;
2060   EXPECT_IS_OK(copy_insertion.Run(module).status());
2061 }
2062 
TEST_F(WhileBufferAssignmentTest,TwoForwardWhileLoops)2063 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
2064   auto module = CreateNewVerifiedModule();
2065   auto builder = HloComputation::Builder("entry");
2066 
2067   auto input0 = builder.AddInstruction(
2068       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2069   auto weights0 = builder.AddInstruction(
2070       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2071   auto weights1 = builder.AddInstruction(
2072       HloInstruction::CreateParameter(2, data_shape_, "weights1"));
2073 
2074   auto zero = builder.AddInstruction(
2075       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2076   auto output0 = builder.AddInstruction(
2077       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2078   auto output1 = builder.AddInstruction(
2079       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2080 
2081   auto cond0 =
2082       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2083   auto body0 =
2084       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2085 
2086   auto tuple0 = builder.AddInstruction(
2087       HloInstruction::CreateTuple({input0, weights0, output0}));
2088   auto while0 = builder.AddInstruction(
2089       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2090 
2091   auto cond1 =
2092       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2093   auto body1 =
2094       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2095   auto input1 = builder.AddInstruction(
2096       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2097   auto tuple1 = builder.AddInstruction(
2098       HloInstruction::CreateTuple({input1, weights1, output1}));
2099   auto while1 = builder.AddInstruction(
2100       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2101 
2102   module->AddEntryComputation(builder.Build());
2103   RunCopyInsertion(module.get());
2104   auto assignment = RunBufferAssignment(module.get());
2105 
2106   // Verify 'input0' and read-only use while0{0} alias.
2107   EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(),
2108             assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie());
2109   // Verify 'weights0' and read-only use while0{1} alias.
2110   EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).ConsumeValueOrDie(),
2111             assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie());
2112   // Verify 'while0{2}' and read-only use while1{0} alias.
2113   EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
2114             assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
2115   // Verify 'weights1' and read-only use while1{1} alias.
2116   EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).ConsumeValueOrDie(),
2117             assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
2118 }
2119 
// Tests that two colocated buffer sets are not merged if an entry parameter
// buffer belongs to either of the colocation sets (b/73267882).
//
// %param.3 --> %while.0 --> %mul --> %while.1 --> %bcast
//
// The %while.0 body just forwards its init value, so the loop-carried variable
// remains the entry parameter, whereas the %while.1 body changes the
// loop-carried variable.
TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) {
  // %while.0's init value is the entry parameter, and its body forwards that
  // value unchanged. %while.1's body modifies its loop-carried value. The
  // parameter must share a slice with %while.0 but not with %while.1.
  const char* module_str = R"(
HloModule test_module

%cond.v0 {
  %param = s32[] parameter(0)
  ROOT %constant = pred[] constant(true)
}

%cond.v1 {
  %param.0 = s32[] parameter(0)
  ROOT %constant.0 = pred[] constant(true)
}

%body.v0 {
  ROOT %param.1 = s32[] parameter(0)
}

%body.v1 {
  %param.2 = s32[] parameter(0)
  ROOT add = s32[] add(%param.2, %param.2)
}

ENTRY %test_module {
  %param.3 = s32[] parameter(0)
  %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0
  %mul = s32[] multiply(%while.0, %while.0)
  %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  // Run CopyInsertion and verify that the graph above needs no copies before
  // BufferAssignment can run.
  int64 instruction_count = m->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
  ASSERT_EQ(instruction_count, m->instruction_count());

  // Locate the instructions by walking back from the root:
  // %bcast -> %while.1 -> %mul -> %while.0. Each hop is opcode-checked so a
  // change to the HLO fails loudly instead of testing the wrong instruction.
  const HloInstruction* bcast = m->entry_computation()->root_instruction();
  const HloInstruction* param =
      m->entry_computation()->parameter_instruction(0);
  ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
  const HloInstruction* while1 = bcast->operand(0);
  ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
  const HloInstruction* mul = while1->operand(0);
  ASSERT_EQ(mul->opcode(), HloOpcode::kMultiply);
  const HloInstruction* while0 = mul->operand(0);
  ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);

  // Run buffer assignment.
  auto assignment = RunBufferAssignment(m.get());
  TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
                          assignment->GetUniqueSlice(param, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));

  // The parameter slice is part of while0's colocation set (init value), but
  // is not merged into while1's colocation set.
  EXPECT_EQ(slice_param, slice_while0);
  EXPECT_NE(slice_param, slice_while1);
}
2193 
TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
  // %while.0's init value is %constant.42, and its body forwards that value
  // unchanged. %while.1's body modifies its loop-carried value. The constant
  // must share a slice with %while.0 but not with %while.1.
  const char* module_str = R"(
HloModule test_module

%cond.v0 {
  %param = s32[] parameter(0)
  ROOT %constant = pred[] constant(true)
}

%cond.v1 {
  %param.0 = s32[] parameter(0)
  ROOT %constant.0 = pred[] constant(true)
}

%body.v0 {
  ROOT %param.1 = s32[] parameter(0)
}

%body.v1 {
  %param.2 = s32[] parameter(0)
  ROOT add = s32[] add(%param.2, %param.2)
}

ENTRY %test_module {
  %constant.42 = s32[] constant(42)
  %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
  %mul = s32[] multiply(%while.0, %while.0)
  %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  // Run CopyInsertion and verify that the graph above needs no copies before
  // BufferAssignment can run.
  int64 instruction_count = m->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
  ASSERT_EQ(instruction_count, m->instruction_count());

  // Locate the instructions. The name lookup is null-checked because the
  // result is dereferenced during buffer assignment queries below.
  const HloInstruction* bcast = m->entry_computation()->root_instruction();
  const HloInstruction* constant =
      m->entry_computation()->GetInstructionWithName("constant.42");
  ASSERT_NE(constant, nullptr);
  ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
  const HloInstruction* while1 = bcast->operand(0);
  ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
  const HloInstruction* mul = while1->operand(0);
  ASSERT_EQ(mul->opcode(), HloOpcode::kMultiply);
  const HloInstruction* while0 = mul->operand(0);
  ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);

  // Run buffer assignment.
  auto assignment = RunBufferAssignment(m.get());
  TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
                          assignment->GetUniqueSlice(constant, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));

  // The constant slice is part of while0's colocation set (init value), but is
  // not merged into while1's colocation set.
  EXPECT_EQ(slice_constant, slice_while0);
  EXPECT_NE(slice_constant, slice_while1);
}
2260 
// Tests that the colocated buffers for while instructions are properly assigned
// during buffer assignment, so that the result tuple elements are not assigned
// to the same buffer.
//
// %infeed --> %while.0 --> %while.1 --+
//                                     +-- %tuple
//   %zero -->   %add   --> %while.2 --+
//
// Execution Order:
// %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
//
// The HLO computation used in this test needs a specific ordering to expose
// the bug (b/72496031). During buffer assignment, the colocated buffers are
// visited in the order %while.2 -> %while.0 -> %while.1. Buffer assignment
// used to coalesce the colocated buffers of all three while instructions,
// which assigned the same buffer to both result tuple elements.
TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
  // Regression test for b/72496031. The exact instruction order built below,
  // and the explicit entry schedule set further down, are what expose the bug.
  // Do not reorder the statements in this test.
  const Shape r0s32 = ShapeUtil::MakeShape(S32, {});

  // Builds a condition computation: x -> x < 4
  auto build_cond = [&]() {
    auto builder = HloComputation::Builder("cond");
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
    builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
                                      const4, ComparisonDirection::kLt));
    return builder.Build();
  };

  // Builds a body computation: x -> x + 9
  auto build_body = [&]() {
    auto builder = HloComputation::Builder("body");
    auto const9 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
    return builder.Build();
  };

  // Build the entry computation: infeed -> while0 -> while1 feeds tuple
  // element {1}, and zero -> add -> while2 feeds tuple element {0}. Each while
  // gets its own freshly built condition and body computations.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder("entry");

  auto token = builder.AddInstruction(HloInstruction::CreateToken());
  auto infeed =
      builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
  auto infeed_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
  auto cond0 = module->AddEmbeddedComputation(build_cond());
  auto body0 = module->AddEmbeddedComputation(build_body());
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));

  // while1 is initialized directly with while0's result.
  auto cond1 = module->AddEmbeddedComputation(build_cond());
  auto body1 = module->AddEmbeddedComputation(build_body());
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond1, body1, while0));

  // while2 is independent of while0/while1; its init value is zero + zero.
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
  auto cond2 = module->AddEmbeddedComputation(build_cond());
  auto body2 = module->AddEmbeddedComputation(build_body());
  auto while2 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond2, body2, add));

  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
  module->AddEntryComputation(builder.Build());

  // Run CopyInsertion and check if the graph constructed above doesn't need
  // any copies inserted for BufferAssignment to run.
  int64 instruction_count = module->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  ASSERT_EQ(instruction_count, module->instruction_count());

  // Create a sequential order among all the instructions in the entry
  // computation, since the issue this test stresses depends on the order the
  // nodes are traversed during BufferAssignment.
  // ScheduleModule produces a schedule for every computation; the entry
  // computation's sequence is then overridden with the exact order below.
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(),
                                     /*pointer_size=*/sizeof(void*));
      }));
  schedule.set_sequence(
      module->entry_computation(),
      {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
  TF_ASSERT_OK(schedule.Verify());

  // Buffer assignment with the backend compiler's size function and an
  // alignment of 1 for every color.
  TF_ASSERT_OK_AND_ASSIGN(
      auto assignment,
      BufferAssigner::Run(
          module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
          backend().compiler()->BufferSizeBytesFunction(),
          [](LogicalBuffer::Color) { return 1; },
          /*allocate_buffers_for_constants=*/true));

  // The result tuple elements must be assigned with different buffers.
  TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
  EXPECT_NE(slice0, slice1);

  // while0 and while1 result buffers must be equal to slice1.
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));
  EXPECT_EQ(slice1, slice_while0);
  EXPECT_EQ(slice1, slice_while1);

  // while2 result buffer must be equal to slice0.
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
                          assignment->GetUniqueSlice(while2, {}));
  EXPECT_EQ(slice0, slice_while2);
}
2384 
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
  // while1 takes while0's entire result tuple as its initial state. The two
  // loop states should therefore share slices element by element.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder("entry");

  HloInstruction* input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  HloInstruction* weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  HloComputation* cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  HloComputation* body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  HloInstruction* init0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  HloInstruction* while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, init0));

  HloComputation* cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  HloComputation* body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  HloInstruction* while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));

  module->AddEntryComputation(builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Every element of while0's state must share a slice with the matching
  // element of while1's state.
  for (int64 index = 0; index < 3; ++index) {
    EXPECT_EQ(assignment->GetUniqueSlice(while0, {index}).ConsumeValueOrDie(),
              assignment->GetUniqueSlice(while1, {index}).ConsumeValueOrDie())
        << "loop state element " << index;
  }
}
2429 
TEST_F(BufferAssignmentTest, TwoCalls) {
  // One sub-computation (x -> x + 1) is called from two call sites.
  // FlattenCallGraph should give each call site its own copy of the callee,
  // after which the two call results must be assigned distinct buffers.
  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
  HloComputation* sub_computation = nullptr;
  {
    auto builder = HloComputation::Builder(TestName() + "_sub_comp");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32, "param"));
    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
    sub_computation = module->AddEmbeddedComputation(builder.Build(add));
  }
  auto builder = HloComputation::Builder(TestName());
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  auto call1 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
  auto call2 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
  auto add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
  module->AddEntryComputation(builder.Build(add2));

  {
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
    EXPECT_TRUE(result);
    // Check the result of flattening instead of building a call graph and
    // discarding it unused.
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
    EXPECT_TRUE(call_graph->IsFlattened());
  }

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}
2471 
TEST_F(BufferAssignmentTest, CallParamCoAllocation) {
  // The callee's parameters must share allocations, leaf by leaf, with the
  // values passed at the call site. Callee param0 {0} pairs with entry_param0,
  // param0 {1} (and its leaves) with custom_call, and param1 with entry_param1.
  const char* hlo_text = R"(
HloModule CallParamCoAllocation

Callee {
  param0 = (f32[100],(f32[200],f32[300])) parameter(0)
  param1 = s32[20] parameter(1)
  ROOT constant = f32[] constant(1)
}

ENTRY Main {
  entry_param0 = f32[100] parameter(0)
  entry_param1 = s32[20]  parameter(1)
  custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target"
  call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call)
  ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee
}
)";

  HloModuleConfig config;
  config.set_debug_options(GetDebugOptionsFromFlags());
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(hlo_text, config));

  auto buffers = RunBufferAssignment(m.get());

  HloComputation* main = m->entry_computation();
  HloComputation* callee = m->GetComputationWithName("Callee");
  // Must be fatal: callee is dereferenced immediately below.
  ASSERT_NE(callee, nullptr);

  HloInstruction* param0 = callee->parameter_instruction(0);
  HloInstruction* param1 = callee->parameter_instruction(1);

  HloInstruction* entry_param0 = main->parameter_instruction(0);
  HloInstruction* entry_param1 = main->parameter_instruction(1);
  HloInstruction* custom_call = main->GetInstructionWithName("custom_call");

  EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}),
            GetAllocation(*buffers, param0, {0}));
  EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}),
            GetAllocation(*buffers, param1, {}));

  EXPECT_EQ(GetAllocation(*buffers, custom_call, {}),
            GetAllocation(*buffers, param0, {1}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}),
            GetAllocation(*buffers, param0, {1, 0}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}),
            GetAllocation(*buffers, param0, {1, 1}));
}
2521 
// Compares BufferAssignment::BufferInfoString() (a CSV with one header row and
// one row per buffer) against a golden string for a small straight-line
// module whose schedule is fixed to program order.
TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
  absl::string_view module_str = R"(
HloModule test_module

ENTRY %test_module {
  %param.0 = s32[1024]{0} parameter(0)
  %param.1 = s32[1024]{0} parameter(1)
  %mul = s32[1024]{0} multiply(%param.0, %param.1)
  %add = s32[1024]{0} add(%mul, %param.0)
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[1024] %add), dimensions={0}
})";

  // Golden output; the definition/use times refer to positions in the
  // instruction sequence passed to the assigner below.
  absl::string_view reference_str =
      R"(buffer_id,buffer_name,offset,size,definition_time,end_time,num_uses,use_times,use_names
0,"<0 param.0 @0>",0,4096,0,5,2,"2;3","mul, operand 0;add, operand 1"
1,"<1 param.1 @0>",0,4096,1,5,1,"2","mul, operand 1"
2,"<2 mul @0>",0,4096,2,3,1,"3","add, operand 0"
3,"<3 add @0>",0,4096,3,4,1,"4","bcast, operand 0"
4,"<4 bcast @0>",0,4194304,4,5,0,"",""
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_str));
  HloInstruction* const lhs = FindInstruction(module.get(), "param.0");
  HloInstruction* const rhs = FindInstruction(module.get(), "param.1");
  HloInstruction* const product = FindInstruction(module.get(), "mul");
  HloInstruction* const sum = FindInstruction(module.get(), "add");
  HloInstruction* const broadcast = FindInstruction(module.get(), "bcast");

  // Assign buffers using the exact program-order sequence.
  auto buffer_assignment = RunBufferAssignmentWithInstructionSequence(
      module.get(), {lhs, rhs, product, sum, broadcast});
  const std::string actual_csv = buffer_assignment->BufferInfoString();

  EXPECT_EQ(actual_csv, reference_str);
}
2556 
// Regression test for b/38494731: two while loops that share the same
// condition/body computations are scheduled back to back in a hand-crafted
// order, and their results must end up in distinct buffers.
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Initial loop state for while0: (input0, weights0, broadcast of 0).
  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // Initial loop state for while1: (input1, weights1, broadcast of 1).
  auto input1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, data_shape_, "input1"));
  auto weights1 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, data_shape_, "weights1"));
  auto output1 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, one, {}));

  // A single condition/body pair is shared by both while loops.
  auto cond =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({input1, weights1, output1}));

  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));

  // Combine element 0 of while0's result with element 1 of while1's result.
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
  auto root_add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));

  module->AddEntryComputation(builder.Build());

  // The whiles share `cond`/`body`, so flattening is expected to change the
  // module (presumably by giving each while its own copies — confirm against
  // FlattenCallGraph).
  {
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
    EXPECT_TRUE(result);
  }

  RunCopyInsertion(module.get());

  HloSchedule schedule =
      ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();

  // To trigger b/38494731, we want a specific Hlo schedule for the
  // entry computation, so we overwrite that entry with a manually
  // crafted sequence: all of while1 runs before any of while0.
  // NOTE: the while operands are read via mutable_operand(0) *after* copy
  // insertion, so the sequence refers to whatever operand each while has at
  // this point rather than to tuple0/tuple1 directly.
  schedule.set_sequence(
      module->entry_computation(),
      {input1, weights1, one, output1, while1->mutable_operand(0), while1,
       input0, weights0, zero, output0, while0->mutable_operand(0), while0,
       gte0, gte1, root_add});

  // If this ASSERT fails, we constructed a bogus sequence above and this test
  // itself is buggy.
  TF_ASSERT_OK(schedule.Verify());

  // Run buffer assignment directly with a SequentialHloOrdering built from the
  // hand-crafted schedule, so the test's ordering is the one used.
  auto assignment =
      BufferAssigner::Run(
          module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
          ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
          /*allocate_buffers_for_constants=*/true)
          .ConsumeValueOrDie();

  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
2636 
TEST_F(WhileBufferAssignmentTest,WhilesDontShareEntryParamIfLiveOut)2637 TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
2638   auto module = CreateNewVerifiedModule();
2639   auto builder = HloComputation::Builder("entry");
2640 
2641   auto input0 = builder.AddInstruction(
2642       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2643   auto weights0 = builder.AddInstruction(
2644       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2645 
2646   auto zero = builder.AddInstruction(
2647       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2648   auto output0 = builder.AddInstruction(
2649       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2650   auto output1 = builder.AddInstruction(
2651       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2652 
2653   auto cond0 =
2654       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2655   auto body0 =
2656       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2657 
2658   auto tuple0 = builder.AddInstruction(
2659       HloInstruction::CreateTuple({input0, weights0, output0}));
2660   auto while0 = builder.AddInstruction(
2661       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2662 
2663   // Get output of 'while0' and feed as input to 'while1'.
2664   auto while0_out = builder.AddInstruction(
2665       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2666 
2667   auto cond1 =
2668       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2669   auto body1 =
2670       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2671 
2672   auto tuple1 = builder.AddInstruction(
2673       HloInstruction::CreateTuple({while0_out, weights0, output1}));
2674   auto while1 = builder.AddInstruction(
2675       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2676 
2677   // Get output of 'while1' so that it is live out of computation.
2678   auto while1_out = builder.AddInstruction(
2679       HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
2680 
2681   module->AddEntryComputation(builder.Build());
2682   RunCopyInsertion(module.get());
2683   auto assignment = RunBufferAssignment(module.get());
2684   // Get BufferAllocation for root instruction.
2685   auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out)
2686                          .ConsumeValueOrDie()
2687                          .allocation();
2688   // Test that root instruction allocation is live out.
2689   EXPECT_TRUE(root_alloc->maybe_live_out());
2690   // Test that root instruction allocation is not an entry parameter.
2691   EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
2692 }
2693 
// Checks that two chained dynamic-update-slice ops inside a while body (the
// second updates the result of the first) are assigned the same buffer slice.
TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) {
  // The while body's ROOT tuple must have the loop-state shape
  // (s32[], f32[1280,1,128]) to match its operands and the while itself.
  const char* const hlo_string = R"(
HloModule test

while_body {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}

while_condition {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  get-tuple-element = s32[] get-tuple-element(state), index=0
  get-tuple-element.1 = s32[] constant(3)
  ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
}

ENTRY entry_computation {
  constant.7 = s32[] constant(0)
  copy.1 = s32[] copy(constant.7)
  constant.6 = f32[] constant(0)
  broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={}
  tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6)
  while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body
  ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0
}

)";
  // Report parse/verification errors as a test failure rather than aborting.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Get the top-level allocation slices of the two dynamic-update-slice ops.
  auto* dus9 = FindInstruction(module.get(), "dynamic-update-slice.9");
  ASSERT_NE(dus9, nullptr);
  TF_ASSERT_OK_AND_ASSIGN(auto dus9_alloc_slice,
                          assignment->GetUniqueTopLevelSlice(dus9));
  auto* dus5 = FindInstruction(module.get(), "dynamic-update-slice.5");
  ASSERT_NE(dus5, nullptr);
  TF_ASSERT_OK_AND_ASSIGN(auto dus5_alloc_slice,
                          assignment->GetUniqueTopLevelSlice(dus5));
  // Test that the two dynamic-update-slice ops share the same allocation slice.
  EXPECT_EQ(dus9_alloc_slice.allocation(), dus5_alloc_slice.allocation());
  EXPECT_EQ(dus9_alloc_slice, dus5_alloc_slice);
}
2746 }  // namespace
2747 }  // namespace xla
2748