1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/memory/memory.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/service/buffer_value.h"
28 #include "tensorflow/compiler/xla/service/call_graph.h"
29 #include "tensorflow/compiler/xla/service/copy_insertion.h"
30 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
31 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
34 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
35 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
36 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
37 #include "tensorflow/compiler/xla/service/hlo_parser.h"
38 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/test.h"
41 #include "tensorflow/compiler/xla/test_helpers.h"
42 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/platform/macros.h"
47 
48 namespace xla {
49 namespace {
50 
51 using ::testing::UnorderedElementsAre;
52 
53 // DFS visitor that collects the instructions referenced by a computation
54 // without descending into nested computations, i.e., only from the operands.
55 class InstructionListVisitor : public DfsHloVisitorWithDefault {
56  public:
InstructionListVisitor(const HloInstruction * root)57   explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}
58 
DefaultAction(HloInstruction * hlo)59   Status DefaultAction(HloInstruction* hlo) override {
60     // For each instruction, just push it on the list after walking the
61     // operands.
62     instructions_.push_back(hlo);
63     VLOG(0) << "List instruction " << hlo->ToString();
64     return Status::OK();
65   }
66 
GetInstructions()67   std::vector<const HloInstruction*> GetInstructions() { return instructions_; }
68 
69  private:
70   // The instruction root of the computation.
71   const HloInstruction* root_;
72 
73   // The full set of instructions found (may be duplicates, e.g., kParameter).
74   std::vector<const HloInstruction*> instructions_;
75 
76   TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor);
77 };
78 
GetInstructions(HloInstruction * root)79 const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
80   InstructionListVisitor main_list(root);
81   TF_CHECK_OK(root->Accept(&main_list));
82   return main_list.GetInstructions();
83 }
84 
85 class BufferAssignmentTest : public HloTestBase {
86  protected:
~BufferAssignmentTest()87   ~BufferAssignmentTest() override {}
88 
RunBufferAssignment(HloModule * module,int64 alignment=1)89   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
90                                                         int64 alignment = 1) {
91     return BufferAssigner::Run(
92                module, absl::make_unique<DependencyHloOrdering>(module),
93                backend().compiler()->BufferSizeBytesFunction(),
94                [alignment](LogicalBuffer::Color) { return alignment; },
95                /*allow_input_output_aliasing=*/false,
96                /*allocate_buffers_for_constants=*/true)
97         .ConsumeValueOrDie();
98   }
99 
RunBufferAssignmentNoBuffersForConstants(HloModule * module,int64 alignment=1)100   std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
101       HloModule* module, int64 alignment = 1) {
102     return BufferAssigner::Run(
103                module, absl::make_unique<DependencyHloOrdering>(module),
104                backend().compiler()->BufferSizeBytesFunction(),
105                [alignment](LogicalBuffer::Color) { return alignment; },
106                /*allow_input_output_aliasing=*/false,
107                /*allocate_buffers_for_constants=*/false)
108         .ConsumeValueOrDie();
109   }
110 
RunBufferAssignmentNoBuffersReuseForAdd(HloModule * module,int64 alignment=1)111   std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersReuseForAdd(
112       HloModule* module, int64 alignment = 1) {
113     auto reuse_checker = [](const BufferAssignment& assignment,
114                             const BufferAllocation& alloc,
115                             const LogicalBuffer& buffer) {
116       return (buffer.instruction()->opcode() != HloOpcode::kAdd);
117     };
118     return BufferAssigner::Run(
119                module, absl::make_unique<DependencyHloOrdering>(module),
120                backend().compiler()->BufferSizeBytesFunction(),
121                [alignment](LogicalBuffer::Color) { return alignment; },
122                /*allow_input_output_aliasing=*/false,
123                /*allocate_buffers_for_constants=*/false,
124                /*colorer=*/BufferLiveness::DefaultColorer(),
125                /*reuse_checker=*/reuse_checker)
126         .ConsumeValueOrDie();
127   }
128 
RunColoredBufferAssignment(HloModule * module,BufferLiveness::Colorer colorer,int64 alignment=1)129   std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
130       HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) {
131     return BufferAssigner::Run(
132                module, absl::make_unique<DependencyHloOrdering>(module),
133                backend().compiler()->BufferSizeBytesFunction(),
134                [alignment](LogicalBuffer::Color) { return alignment; },
135                /*allow_input_output_aliasing=*/false,
136                /*allocate_buffers_for_constants=*/true, std::move(colorer))
137         .ConsumeValueOrDie();
138   }
139 
RunBufferAssignmentWithInstructionSequence(HloModule * module,absl::Span<HloInstruction * const> instruction_sequence,int64 alignment=1)140   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
141       HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
142       int64 alignment = 1) {
143     HloSchedule schedule(module);
144     schedule.set_sequence(module->entry_computation(), instruction_sequence);
145     return BufferAssigner::Run(
146                module, absl::make_unique<SequentialHloOrdering>(schedule),
147                backend().compiler()->BufferSizeBytesFunction(),
148                [alignment](LogicalBuffer::Color) { return alignment; },
149                /*allow_input_output_aliasing=*/false,
150                /*allocate_buffers_for_constants=*/true)
151         .ConsumeValueOrDie();
152   }
153 
154   // Builds an x+1.0 computation to use in a Map.
BuildMapComputationPlus1(const string & name)155   std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
156     auto builder = HloComputation::Builder(name);
157     auto param =
158         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
159     auto value = builder.AddInstruction(
160         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
161     builder.AddInstruction(
162         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
163     return builder.Build();
164   }
165 
BuildReduceComputation(const string & name)166   std::unique_ptr<HloComputation> BuildReduceComputation(const string& name) {
167     auto builder = HloComputation::Builder(name);
168     auto param =
169         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
170     auto param2 =
171         builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
172     builder.AddInstruction(
173         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
174     return builder.Build();
175   }
176 
177   // Builds a simple compare-to-limit (x < 4) computation for a While.
178   //
179   // condition:
180   //   const4[s32] -----------------------------------\
181   //                                                   \
182   //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
183   //
BuildWhileConditionComputation(const string & name)184   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
185       const string& name) {
186     auto builder = HloComputation::Builder(name);
187     auto const4 = builder.AddInstruction(
188         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
189     auto param = builder.AddInstruction(
190         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
191     auto index = builder.AddInstruction(
192         HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
193     builder.AddInstruction(
194         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
195                                       const4, ComparisonDirection::kLt));
196     return builder.Build();
197   }
198 
199   // Builds a simple body computation for a While.
200   //
201   // body:
202   //   constv[f32[4]] --------------------------------------\
203   //                                                         \
204   //                           /--- get-tuple-elementv[1] --- addv ---\
205   //   param[(s32,f32[4])] ---|                                    tuple
206   //                           \--- get-tuple-elementc[0] --- addc ---/
207   //                                                         /
208   //   const1[s32] -----------------------------------------/
209   //
BuildWhileBodyComputation(const string & name)210   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
211       const string& name) {
212     auto builder = HloComputation::Builder(name);
213     auto const1 = builder.AddInstruction(
214         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
215     auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
216         LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
217     auto param = builder.AddInstruction(
218         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
219     auto indexc = builder.AddInstruction(
220         HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
221     auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
222         indexc->shape(), HloOpcode::kAdd, indexc, const1));
223     auto indexv = builder.AddInstruction(
224         HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
225     auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
226         constv->shape(), HloOpcode::kAdd, indexv, constv));
227     builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
228     return builder.Build();
229   }
230 
BuildR0F32UnaryOpComputation(HloOpcode opcode,const string & name)231   std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
232       HloOpcode opcode, const string& name) {
233     auto builder = HloComputation::Builder(name);
234     auto param =
235         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
236     builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
237     return builder.Build();
238   }
239 
240   // Verifies that the given instruction hlo has a valid input buffer assigned,
241   // i.e., the parameter number matches the op's.
GetAssignedInputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)242   const BufferAllocation& GetAssignedInputAllocation(
243       const BufferAssignment& buffers, HloInstruction* hlo) {
244     LOG(INFO) << "Checking input: " << hlo->ToString();
245     const BufferAllocation& buffer =
246         *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation();
247     EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
248     return buffer;
249   }
250 
251   // Verifies that the given instruction hlo has a valid output buffer
252   // assigned, and returns it.
GetAssignedOutputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)253   const BufferAllocation& GetAssignedOutputAllocation(
254       const BufferAssignment& buffers, HloInstruction* hlo) {
255     LOG(INFO) << "Checking output: " << hlo->ToString();
256     const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
257     return buffer;
258   }
259 
260   // Returns the allocation for the given instruction.
GetAllocation(const BufferAssignment & buffers,const HloInstruction * hlo,const ShapeIndex & index)261   const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
262                                         const HloInstruction* hlo,
263                                         const ShapeIndex& index) {
264     return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation();
265   }
GetTopLevelAllocation(const BufferAssignment & buffers,const HloInstruction * hlo)266   const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
267                                                 const HloInstruction* hlo) {
268     return *buffers.GetUniqueTopLevelSlice(hlo)
269                 .ConsumeValueOrDie()
270                 .allocation();
271   }
272 
273   // Verifies that all instructions in the given instruction list except
274   // kConstant have assigned buffers, and returns their total size. If min_index
275   // and max_index are not nullptr, the minimum and maximum buffer indices in
276   // the assignment are written into them.
ValidateBuffers(const std::vector<const HloInstruction * > & instructions,const BufferAssignment & buffers)277   int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions,
278                         const BufferAssignment& buffers) {
279     // Verifies all instructions have buffers, and gets the index ranges.
280     for (const HloInstruction* hlo : instructions) {
281       if (!buffers.HasTopLevelAllocation(hlo)) {
282         // If `hlo` has no assigned buffer, it is either a constant or a nested
283         // parameter.
284         EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
285                     HloOpcode::kParameter == hlo->opcode());
286         continue;
287       }
288     }
289 
290     // Gets the total size of all buffers assigned.
291     int64 total_size = 0;
292     for (auto& allocation : buffers.Allocations()) {
293       total_size += allocation.size();
294     }
295     return total_size;
296   }
297 
298   // Shapes for use in the examples.
299   Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
300   Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
301   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
302   Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
303   Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
304   Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
305   Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
306   Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
307 };
308 
309 // Returns true if the buffers assigned to instructions in "a" are distinct
310 // from the buffers assigned to those in "b" (ie, intersection is empty).
BuffersDistinct(const std::vector<const HloInstruction * > & a,const std::vector<const HloInstruction * > & b,const BufferAssignment & assignment)311 static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
312                             const std::vector<const HloInstruction*>& b,
313                             const BufferAssignment& assignment) {
314   absl::flat_hash_set<BufferAllocation::Slice> a_slices;
315   for (const HloInstruction* instruction : a) {
316     if (assignment.HasTopLevelAllocation(instruction)) {
317       a_slices.insert(
318           assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie());
319     }
320   }
321 
322   for (const HloInstruction* instruction : b) {
323     if (assignment.HasTopLevelAllocation(instruction)) {
324       if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction)
325                                 .ConsumeValueOrDie())) {
326         return false;
327       }
328     }
329   }
330   return true;
331 }
332 
333 // Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest,ScalarConstant)334 TEST_F(BufferAssignmentTest, ScalarConstant) {
335   auto builder = HloComputation::Builder(TestName());
336   auto const0 = builder.AddInstruction(
337       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
338   auto module = CreateNewVerifiedModule();
339   module->AddEntryComputation(builder.Build());
340 
341   {
342     auto buffers = RunBufferAssignment(module.get());
343     EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
344   }
345 
346   {
347     auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
348     EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
349   }
350 }
351 
TEST_F(BufferAssignmentTest, BufferForConst) {
  // Addition of two vector constants. The constants have top-level
  // allocations only when buffers for constants are requested; their consumer
  // has a buffer in both configurations.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    // With buffers for constants.
    std::unique_ptr<BufferAssignment> assignment =
        RunBufferAssignment(module.get());
    EXPECT_TRUE(assignment->HasTopLevelAllocation(lhs));
    EXPECT_TRUE(assignment->HasTopLevelAllocation(rhs));
    GetAssignedOutputAllocation(*assignment, sum);
  }
  {
    // Without buffers for constants.
    std::unique_ptr<BufferAssignment> assignment =
        RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(assignment->HasTopLevelAllocation(lhs));
    EXPECT_FALSE(assignment->HasTopLevelAllocation(rhs));
    GetAssignedOutputAllocation(*assignment, sum);
  }
}
378 
TEST_F(BufferAssignmentTest, HasAllocationAt) {
  // Builds a tuple of (negate(param0), param0, constant) and checks that
  // HasAllocationAt() on the tuple agrees with HasTopLevelAllocation() on the
  // instruction that produces each element.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate, param0, constant}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  std::unique_ptr<BufferAssignment> assignment =
      RunBufferAssignment(module.get());

  // What each instruction reports for itself.
  const bool tuple_allocated = assignment->HasTopLevelAllocation(tuple);
  const bool negate_allocated = assignment->HasTopLevelAllocation(negate);
  const bool param0_allocated = assignment->HasTopLevelAllocation(param0);
  const bool constant_allocated = assignment->HasTopLevelAllocation(constant);

  // The same question asked through the tuple's shape indices.
  EXPECT_EQ(tuple_allocated, assignment->HasAllocationAt(tuple, /*index=*/{}));
  EXPECT_EQ(negate_allocated,
            assignment->HasAllocationAt(tuple, /*index=*/{0}));
  EXPECT_EQ(param0_allocated,
            assignment->HasAllocationAt(tuple, /*index=*/{1}));
  EXPECT_EQ(constant_allocated,
            assignment->HasAllocationAt(tuple, /*index=*/{2}));
}
406 
TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  // The computation copies a vector constant to the output; the copy must
  // have an output buffer assigned.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* copied = builder.AddInstruction(
      HloInstruction::CreateUnary(source->shape(), HloOpcode::kCopy, source));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  std::unique_ptr<BufferAssignment> assignment =
      RunBufferAssignment(module.get());
  // The lookup CHECK-fails if the copy has no top-level allocation.
  GetAssignedOutputAllocation(*assignment, copied);
}
421 
// Checks the default assignment for a small elementwise chain: parameters get
// distinct input allocations, mul does not share with param0, and add shares
// mul's allocation.
TEST_F(BufferAssignmentTest, Basic) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Distinct input buffers were assigned for parameters. The helpers return
  // references into `buffers`, so bind by reference rather than copying each
  // BufferAllocation.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
468 
TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
  // If an input buffer and an output buffer alias, the input buffer can be
  // reused for other intermediate results.
  //
  // param0[100] ----- (neg1) -- (neg2)
  //    |                           |
  //    + -------- Aliased ---------+

  auto builder = HloComputation::Builder(TestName());

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p0"));
  auto neg_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  auto neg_2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Register a user alias between the top-level output and parameter 0.
  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias(
      {}, 0, {}, HloInputOutputAliasConfig::kUserAlias));

  auto buffers = RunBufferAssignment(module.get());

  // The helpers return references into `buffers`; bind by reference instead
  // of copying each BufferAllocation.
  const BufferAllocation& param_buffer =
      GetAssignedInputAllocation(*buffers, param);
  const BufferAllocation& neg_1_buffer = GetAllocation(*buffers, neg_1, {});
  const BufferAllocation& neg_2_buffer = GetAllocation(*buffers, neg_2, {});

  // Everything uses one buffer.
  EXPECT_EQ(param_buffer.index(), neg_1_buffer.index());
  EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index());
}
502 
TEST_F(BufferAssignmentTest, AddCannotReuse) {
  // Pass in a special rule to indicate that "add" cannot reuse any buffer.
  //
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get());

  // Distinct input buffers were assigned for parameters. The helpers return
  // references into `buffers`, so bind by reference rather than copying each
  // BufferAllocation.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node cannot reuse the mul node's buffer since we told buffer
  // assignment so.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
552 
TEST_F(BufferAssignmentTest,BasicUniquelyColored)553 TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
554   // paramscalar ------- (mul) -- (add) -- (sub)
555   //                     /        /        /
556   // param0[100] -------/        /        /
557   //                            /        /
558   // param1[100] --------------/--------/
559   // The output of each op is colored with a different color, so we can not
560   // share anything.
561   auto builder = HloComputation::Builder(TestName());
562   auto paramscalar =
563       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
564   auto broadcast = builder.AddInstruction(
565       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
566   auto param0 = builder.AddInstruction(
567       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
568   auto param1 = builder.AddInstruction(
569       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
570   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
571       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
572   auto add = builder.AddInstruction(
573       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
574   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
575       f32vec100_, HloOpcode::kSubtract, add, param1));
576   auto module = CreateNewVerifiedModule();
577   module->AddEntryComputation(builder.Build());
578 
579   auto colorer = [](const BufferLiveness& buffer_liveness) {
580     int color = 0;
581 
582     for (LogicalBuffer::Id id = 0;
583          id < buffer_liveness.points_to_analysis().num_logical_buffers();
584          id++) {
585       auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
586       buffer.set_color(LogicalBuffer::Color(color++));
587     }
588     return Status::OK();
589   };
590 
591   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
592 
593   // Distinct input buffers were assigned for parameters.
594   BufferAllocation paramscalar_buffer =
595       GetAssignedInputAllocation(*buffers, paramscalar);
596   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
597   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
598   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
599   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
600   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
601 
602   // The mul node has a valid buffer assigned, doesn't share with input.
603   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
604   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
605 
606   // The add node can not reuse the mul node's buffer due to coloring.
607   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
608   EXPECT_NE(add_buffer.index(), mul_buffer.index());
609 
610   // The sub node has a valid output buffer assigned.
611   GetAssignedOutputAllocation(*buffers, sub);
612 }
613 
TEST_F(BufferAssignmentTest,BasicPartiallyColored)614 TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
615   // paramscalar ------- (mul) -- (add) -- (sub)
616   //                     /        /        /
617   // param0[100] -------/        /        /
618   //                            /        /
619   // param1[100] --------------/--------/
620   // The output of the mul and the add have the color 1, and the other buffers
621   // have the color 0, which allows the mul and add to share buffers.
622   auto builder = HloComputation::Builder(TestName());
623   auto paramscalar =
624       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
625   auto broadcast = builder.AddInstruction(
626       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
627   auto param0 = builder.AddInstruction(
628       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
629   auto param1 = builder.AddInstruction(
630       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
631   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
632       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
633   auto add = builder.AddInstruction(
634       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
635   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
636       f32vec100_, HloOpcode::kSubtract, add, param1));
637   auto module = CreateNewVerifiedModule();
638   module->AddEntryComputation(builder.Build());
639 
640   auto colorer = [](const BufferLiveness& buffer_liveness) {
641     for (LogicalBuffer::Id id = 0;
642          id < buffer_liveness.points_to_analysis().num_logical_buffers();
643          id++) {
644       auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
645       const auto& aliases =
646           buffer_liveness.points_to_analysis().GetBufferAliases(buffer);
647       for (const auto& alias : aliases) {
648         if (alias.instruction()->opcode() == HloOpcode::kAdd ||
649             alias.instruction()->opcode() == HloOpcode::kMultiply) {
650           buffer.set_color(LogicalBuffer::Color(1));
651         }
652       }
653       if (!buffer.has_color()) {
654         buffer.set_color(LogicalBuffer::Color(0));
655       }
656     }
657     return Status::OK();
658   };
659 
660   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
661 
662   // Distinct input buffers were assigned for parameters.
663   BufferAllocation paramscalar_buffer =
664       GetAssignedInputAllocation(*buffers, paramscalar);
665   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
666   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
667   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
668   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
669   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
670 
671   // The mul node has a valid buffer assigned, doesn't share with input.
672   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
673   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
674 
675   // The add node can reuse the mul node's buffer.
676   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
677   EXPECT_EQ(add_buffer.index(), mul_buffer.index());
678 
679   // The sub node has a valid output buffer assigned.
680   GetAssignedOutputAllocation(*buffers, sub);
681 }
682 
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
  // Variant of the Basic test in which (sub) consumes (mul)'s result as well,
  // so (mul)'s buffer is still needed when (add) runs and cannot be reused for
  // (add)'s output.
  //
  //   paramscalar --+
  //                 |          +----------------+
  //                 v          |                v
  //   param0[100] --> (mul) ---+--> (add) --> (sub)
  //                                   ^
  //   param1[100] --------------------+
  //
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* scalar_param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* scalar_bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, scalar_param, {}));
  HloInstruction* vec_param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* vec_param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* product = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, scalar_bcast, vec_param0));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kAdd, product, vec_param1));
  HloInstruction* difference = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, sum,
                                   product));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  // All three parameters have pairwise distinct input allocations.
  BufferAllocation scalar_alloc =
      GetAssignedInputAllocation(*assignment, scalar_param);
  BufferAllocation param0_alloc =
      GetAssignedInputAllocation(*assignment, vec_param0);
  BufferAllocation param1_alloc =
      GetAssignedInputAllocation(*assignment, vec_param1);
  EXPECT_NE(scalar_alloc.index(), param0_alloc.index());
  EXPECT_NE(scalar_alloc.index(), param1_alloc.index());
  EXPECT_NE(param0_alloc.index(), param1_alloc.index());

  // The mul has an allocation, and this time the add may not share it.
  const BufferAllocation& mul_alloc =
      GetTopLevelAllocation(*assignment, product);
  const BufferAllocation& add_alloc = GetTopLevelAllocation(*assignment, sum);
  EXPECT_NE(add_alloc.index(), mul_alloc.index());

  // Emit size information so it can be inspected in the test log.
  const std::vector<const HloInstruction*> instructions =
      GetInstructions(difference);
  int64 total_size = ValidateBuffers(instructions, *assignment);
  LOG(INFO) << "LogicalBuffer count " << assignment->Allocations().size()
            << " for " << instructions.size()
            << " instructions; total buffer size " << total_size;
}
737 
TEST_F(BufferAssignmentTest, TrivialMap) {
  // The entry computation consists of nothing but a map of x+1:
  //
  //   param0[100x10] ---> (map x+1)
  //
  // Start with the computation that the map applies.
  auto module = CreateNewVerifiedModule();
  HloComputation* plus_one =
      module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
  HloInstruction* plus_one_root = plus_one->root_instruction();

  // Then the entry computation; instruction counts are checked right after.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  HloInstruction* mapped = builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10_, {input}, plus_one));
  module->AddEntryComputation(builder.Build());

  const std::vector<const HloInstruction*> entry_instrs =
      GetInstructions(mapped);
  EXPECT_EQ(2, entry_instrs.size()) << "Invalid main kernel size";
  const std::vector<const HloInstruction*> nested_instrs =
      GetInstructions(plus_one_root);
  EXPECT_EQ(3, nested_instrs.size()) << "Invalid nested add+1 size";

  // Run buffer assignment and collect the validated sizes of both levels.
  auto assignment = RunBufferAssignment(module.get());
  int64 entry_size = ValidateBuffers(entry_instrs, *assignment);
  int64 nested_size = ValidateBuffers(nested_instrs, *assignment);

  // The map's buffer is assigned before the embedded computation is
  // processed, so no buffer may be shared between the two levels.
  EXPECT_TRUE(BuffersDistinct(entry_instrs, nested_instrs, *assignment))
      << "Reuse between main kernel and embedded mapping.";

  // The parameter has an input allocation and the map an output allocation,
  // and the two differ.
  BufferAllocation input_alloc = GetAssignedInputAllocation(*assignment, input);
  BufferAllocation map_alloc = GetAssignedOutputAllocation(*assignment, mapped);
  EXPECT_NE(input_alloc.index(), map_alloc.index());

  // The embedded computation ends in an add (f32 parameter plus a constant)
  // whose allocation is separate from the map's.
  EXPECT_EQ(HloOpcode::kAdd, plus_one_root->opcode());
  const BufferAllocation& nested_add_alloc =
      GetTopLevelAllocation(*assignment, plus_one_root);
  EXPECT_NE(nested_add_alloc.index(), map_alloc.index());

  // Emit size information so it can be inspected in the test log.
  const auto instruction_count = entry_instrs.size() + nested_instrs.size();
  const int64 total_size = entry_size + nested_size;
  LOG(INFO) << "LogicalBuffer count " << assignment->Allocations().size()
            << " for " << instruction_count
            << " instructions; total buffer size " << total_size;
}
792 
TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  // A reduce must not write its result into its operand's buffer. Sharing is
  // unsafe in general: the output is reshaped, and a reduction that visits
  // elements out of order could overwrite an input element before it is read.
  //
  //   param0[100x10] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
  auto module = CreateNewVerifiedModule();
  HloComputation* add_reducer =
      module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  HloInstruction* first_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, input));
  HloInstruction* second_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, first_exp));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* reduction = builder.AddInstruction(
      HloInstruction::CreateReduce(/*shape=*/f32vec10_,
                                   /*operand=*/second_exp,
                                   /*init_value=*/zero,
                                   /*dimensions_to_reduce=*/{0}, add_reducer));
  HloInstruction* last_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduction));

  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());
  const std::vector<const HloInstruction*> all_instrs =
      GetInstructions(last_exp);
  ValidateBuffers(all_instrs, *assignment);

  const BufferAllocation& first_exp_alloc =
      GetTopLevelAllocation(*assignment, first_exp);
  const BufferAllocation& second_exp_alloc =
      GetTopLevelAllocation(*assignment, second_exp);
  const BufferAllocation& reduce_alloc =
      GetTopLevelAllocation(*assignment, reduction);

  // Sanity check: exp2 trivially reuses exp1's buffer.
  EXPECT_EQ(first_exp_alloc.index(), second_exp_alloc.index());

  // The reduce, however, may not take over exp2's buffer even though exp2 is
  // its only array operand.
  EXPECT_NE(second_exp_alloc.index(), reduce_alloc.index());
}
839 
TEST_F(BufferAssignmentTest,ExampleWhile)840 TEST_F(BufferAssignmentTest, ExampleWhile) {
841   // This tests a While loop example from the ir_semantics document.
842   //
843   // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
844   // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
845   //
846   // const3[s32] -------\
847   // const4[f32[4]] --- tuple --- while[condition, body]
848   //
849   // Builds the nested condition and body.
850   auto module = CreateNewVerifiedModule();
851   auto condition_computation =
852       module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
853   auto body_computation =
854       module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));
855 
856   // Creates the main kernel and verifies instruction counts.
857   auto builder = HloComputation::Builder(TestName());
858   auto const3 = builder.AddInstruction(
859       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
860   auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
861       LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
862   auto tuple =
863       builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
864   auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
865       t_s32_f32v4_, condition_computation, body_computation, tuple));
866   module->AddEntryComputation(builder.Build());
867 
868   const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
869   EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
870   const std::vector<const HloInstruction*> levelc =
871       GetInstructions(condition_computation->root_instruction());
872   EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
873   const std::vector<const HloInstruction*> levelb =
874       GetInstructions(body_computation->root_instruction());
875   EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";
876 
877   // Assigns buffers and fetches sizes.
878   auto buffers = RunBufferAssignment(module.get());
879   int64 size0 = ValidateBuffers(level0, *buffers);
880   int64 sizec = ValidateBuffers(levelc, *buffers);
881   int64 sizeb = ValidateBuffers(levelb, *buffers);
882 
883   // BufferAssignment will assign a single allocation for the following
884   // instructions: while, while.cond.param, while.body.param, while.body.result.
885   EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
886       << "Should be reuse between main kernel and embedded condition.";
887   EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
888       << "Should be reuse between embedded condition and body.";
889   // Expect buffer reuse between main kernel and body computation.
890   EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
891       << "Should be reuse between main kernel and embedded body.";
892 
893   // The final computation node of the while body is a tuple of s32 and
894   // f32[4] adds.
895   HloInstruction* body_root = body_computation->root_instruction();
896   EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());
897 
898   // Check that buffer for each subshape of 'while_op' shares allocation with
899   // corresponding buffer from while body computation at same index.
900   ShapeUtil::ForEachSubshape(
901       while_op->shape(),
902       [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
903                                             const ShapeIndex& index) {
904         auto while_op_allocation = GetAllocation(*buffers, while_op, index);
905         auto body_root_allocation = GetAllocation(*buffers, body_root, index);
906         EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index());
907       });
908 
909   // Log size information for inspection.
910   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
911             << " for " << level0.size() + levelc.size() + levelb.size()
912             << " instructions; total buffer size " << size0 + sizec + sizeb;
913 }
914 
TEST_F(BufferAssignmentTest, ExampleConditional) {
  // A conditional over two scalar constants whose branch computations are
  // built with kCeil (true branch) and kFloor (false branch). The conditional
  // and both branches are expected to share buffers with one another.
  auto module = CreateNewVerifiedModule();
  HloComputation* ceil_branch = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
  HloComputation* floor_branch = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* true_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* false_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
  HloInstruction* branch_op =
      builder.AddInstruction(HloInstruction::CreateConditional(
          r0f32_, predicate, true_operand, ceil_branch, false_operand,
          floor_branch));
  module->AddEntryComputation(builder.Build());

  HloInstruction* ceil_root = ceil_branch->root_instruction();
  HloInstruction* floor_root = floor_branch->root_instruction();

  const std::vector<const HloInstruction*> entry_instrs =
      GetInstructions(branch_op);
  const std::vector<const HloInstruction*> ceil_instrs =
      GetInstructions(ceil_root);
  const std::vector<const HloInstruction*> floor_instrs =
      GetInstructions(floor_root);
  EXPECT_EQ(4, entry_instrs.size());
  EXPECT_EQ(2, ceil_instrs.size());
  EXPECT_EQ(2, floor_instrs.size());

  auto assignment = RunBufferAssignment(module.get());
  ValidateBuffers(entry_instrs, *assignment);
  ValidateBuffers(ceil_instrs, *assignment);
  ValidateBuffers(floor_instrs, *assignment);

  // Every pair of levels must have at least one buffer in common.
  EXPECT_FALSE(BuffersDistinct(entry_instrs, ceil_instrs, *assignment))
      << "Should be reuse between conditional and true computation.";
  EXPECT_FALSE(BuffersDistinct(entry_instrs, floor_instrs, *assignment))
      << "Should be reuse between conditional and false computation.";
  EXPECT_FALSE(BuffersDistinct(ceil_instrs, floor_instrs, *assignment))
      << "Should be reuse between true and false computations.";

  // The conditional's allocation has the same size as each branch root's.
  const BufferAllocation& conditional_alloc =
      GetTopLevelAllocation(*assignment, branch_op);
  const BufferAllocation& ceil_alloc =
      GetTopLevelAllocation(*assignment, ceil_root);
  const BufferAllocation& floor_alloc =
      GetTopLevelAllocation(*assignment, floor_root);
  EXPECT_EQ(conditional_alloc.size(), ceil_alloc.size());
  EXPECT_EQ(conditional_alloc.size(), floor_alloc.size());
}
964 
TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
  // A straight chain of unary ops that all produce the same shape:
  //
  //   param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
  //
  // Every op after the first exp is expected to be placed in the allocation
  // that was assigned to the first exp.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // ASSERT rather than EXPECT: the comparisons below look up exp1's
  // allocation, which is only meaningful if one was actually assigned. Stop
  // the test here with a clean failure otherwise.
  ASSERT_TRUE(assignment->HasTopLevelAllocation(exp1));

  // tanh, exp2 and neg all reuse exp1's buffer.
  const BufferAllocation& buffer_for_exp1 =
      GetTopLevelAllocation(*assignment, exp1);
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh));
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2));
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg));
}
990 
TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
  // A chain whose buffer size first shrinks (via slice) and then grows again
  // (via broadcast):
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The broadcast is expected to reuse negate's buffer even though negate is
  // not one of its operands; the slice in between keeps a buffer of its own.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // ASSERT rather than EXPECT: the comparisons below look up the broadcast's
  // allocation, which is only meaningful if one was actually assigned. Stop
  // the test here with a clean failure otherwise.
  ASSERT_TRUE(assignment->HasTopLevelAllocation(broadcast));

  // negate and broadcast should share a buffer.
  const BufferAllocation& buffer_for_bcast =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}
1020 
TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
  // Same chain as in ReuseNonOperandBuffer, except that negate is also an
  // operand of the output tuple. That keeps negate's value live until the end
  // of the computation, so its buffer cannot be handed to the broadcast.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast) ---> (tuple)
  //                  |                                          ^
  //                  +------------------------------------------+
  //
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* sliced = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negated, {0}, {10}, {1}));
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, sliced, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({negated, bcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate, slice and broadcast must end up in pairwise distinct allocations.
  const BufferAllocation& bcast_alloc =
      GetTopLevelAllocation(*assignment, bcast);
  const BufferAllocation& negate_alloc =
      GetTopLevelAllocation(*assignment, negated);
  const BufferAllocation& slice_alloc =
      GetTopLevelAllocation(*assignment, sliced);
  EXPECT_NE(bcast_alloc, negate_alloc);
  EXPECT_NE(bcast_alloc, slice_alloc);
  EXPECT_NE(negate_alloc, slice_alloc);
}
1053 
TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
  // Same chain as in ReuseNonOperandBuffer, except that negate is wrapped in a
  // tuple that is itself part of the output. Through that aliasing, negate's
  // buffer stays live until the end of the computation and cannot be reused
  // by the broadcast.
  //
  //   param --> (negate) --> (tuple) --> (slice) --> (broadcast) --> (tuple)
  //                             |                                       ^
  //                             +---------------------------------------+
  //
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* wrapper =
      builder.AddInstruction(HloInstruction::CreateTuple({negated}));
  HloInstruction* unwrapped = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec100_, wrapper, 0));
  HloInstruction* sliced = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, unwrapped, {0}, {10}, {1}));
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, sliced, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({wrapper, bcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate, slice and broadcast must end up in pairwise distinct allocations.
  const BufferAllocation& bcast_alloc =
      GetTopLevelAllocation(*assignment, bcast);
  const BufferAllocation& negate_alloc =
      GetTopLevelAllocation(*assignment, negated);
  const BufferAllocation& slice_alloc =
      GetTopLevelAllocation(*assignment, sliced);
  EXPECT_NE(bcast_alloc, negate_alloc);
  EXPECT_NE(bcast_alloc, slice_alloc);
  EXPECT_NE(negate_alloc, slice_alloc);
}
1090 
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
  // Close to ReuseNonOperandBuffer, but here the broadcast's output is smaller
  // than negate's. Output buffers of a computation should be sized exactly for
  // their value, so the broadcast must not take over negate's larger buffer.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The broadcast may share a buffer with neither negate nor slice.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // 100 elements come out of the negate ...
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  // ... 10 out of the slice ...
  HloInstruction* sliced = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negated, {0}, {10}, {1}));
  // ... and 40 (10x4) out of the broadcast.
  HloInstruction* bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {10, 4}), sliced, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast's output allocation is shared with nobody.
  const BufferAllocation& bcast_alloc =
      GetTopLevelAllocation(*assignment, bcast);
  EXPECT_NE(bcast_alloc, GetTopLevelAllocation(*assignment, negated));
  EXPECT_NE(bcast_alloc, GetTopLevelAllocation(*assignment, sliced));
}
1123 
TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
  // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
  // output is exactly the same size as the negate (rather than being
  // smaller). This enables reuse of negate's buffer by the broadcast because
  // the output buffer will be sized exactly to its value.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast; the slice should not.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 100 elements (10x10), the same count as the negate.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // ASSERT rather than EXPECT: the comparisons below look up the broadcast's
  // allocation, which is only meaningful if one was actually assigned. Stop
  // the test here with a clean failure otherwise.
  ASSERT_TRUE(assignment->HasTopLevelAllocation(broadcast));

  // negate and broadcast should share a buffer.
  const BufferAllocation& buffer_for_bcast =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}
1157 
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
  // Close to ReuseNonOperandBuffer, but here the broadcast's output is smaller
  // than negate's and reaches the computation output as an element of a
  // tuple. Output buffers of a computation should be sized exactly for their
  // value, and that includes buffers aliased in the output (e.g. tuple
  // elements), so the broadcast must not take over negate's larger buffer.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
  //
  // The broadcast may share a buffer with neither negate nor slice.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // 100 elements come out of the negate ...
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  // ... 10 out of the slice ...
  HloInstruction* sliced = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negated, {0}, {10}, {1}));
  // ... and 40 (10x4) out of the broadcast, which is then wrapped in a tuple.
  HloInstruction* bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {10, 4}), sliced, {0}));
  builder.AddInstruction(HloInstruction::CreateTuple({bcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast's output allocation is shared with nobody.
  const BufferAllocation& bcast_alloc =
      GetTopLevelAllocation(*assignment, bcast);
  EXPECT_NE(bcast_alloc, GetTopLevelAllocation(*assignment, negated));
  EXPECT_NE(bcast_alloc, GetTopLevelAllocation(*assignment, sliced));
}
1193 
TEST_F(BufferAssignmentTest,EmbeddedComputationBuffers)1194 TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
1195   // Verify that buffers for embedded computations are properly marked as
1196   // thread-local and that embedded parameters are not marked as
1197   // is_entry_computation_parameter.
1198   auto module = CreateNewVerifiedModule();
1199   auto vec_shape = ShapeUtil::MakeShape(F32, {42});
1200   auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1201 
1202   // Create a scalar computation to use in a map.
1203   auto map_builder = HloComputation::Builder(TestName() + "_map");
1204   auto map_param = map_builder.AddInstruction(
1205       HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
1206   auto map_root = map_builder.AddInstruction(
1207       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
1208   auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
1209 
1210   // Create a vector computation to use in a kCall.
1211   auto call_builder = HloComputation::Builder(TestName() + "_call");
1212   auto call_param = call_builder.AddInstruction(
1213       HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
1214   auto call_root = call_builder.AddInstruction(
1215       HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
1216   auto call_computation = module->AddEmbeddedComputation(call_builder.Build());
1217 
1218   // Create entry computation which kCalls call_computation and then calls map
1219   // with map_computation on the result.
1220   auto builder = HloComputation::Builder(TestName());
1221   auto param = builder.AddInstruction(
1222       HloInstruction::CreateParameter(0, vec_shape, "param"));
1223   auto call = builder.AddInstruction(
1224       HloInstruction::CreateCall(vec_shape, {param}, call_computation));
1225   auto map = builder.AddInstruction(
1226       HloInstruction::CreateMap(vec_shape, {call}, map_computation));
1227   module->AddEntryComputation(builder.Build());
1228 
1229   auto assignment = RunBufferAssignment(module.get());
1230 
1231   // Allocations for the map computation should be thread-local and not
1232   // live-out.
1233   auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
1234   EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
1235   EXPECT_FALSE(map_param_alloc.maybe_live_out());
1236   EXPECT_TRUE(map_param_alloc.is_thread_local());
1237 
1238   auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
1239   EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
1240   EXPECT_FALSE(map_root_alloc.maybe_live_out());
1241   EXPECT_TRUE(map_root_alloc.is_thread_local());
1242 
1243   // Allocations for the call computation should not be thread-local.
1244   auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
1245   EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter());
1246   EXPECT_FALSE(call_param_alloc.maybe_live_out());
1247   EXPECT_FALSE(call_param_alloc.is_thread_local());
1248 
1249   auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
1250   EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
1251   EXPECT_FALSE(call_root_alloc.is_thread_local());
1252 
1253   // Entry computation allocations can be marked liveout and
1254   // is_entry_computation_parameter.
1255   auto& param_alloc = GetTopLevelAllocation(*assignment, param);
1256   EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
1257   EXPECT_FALSE(param_alloc.maybe_live_out());
1258   EXPECT_FALSE(param_alloc.is_thread_local());
1259 
1260   auto& map_alloc = GetTopLevelAllocation(*assignment, map);
1261   EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
1262   EXPECT_TRUE(map_alloc.maybe_live_out());
1263   EXPECT_FALSE(map_alloc.is_thread_local());
1264 }
1265 
TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
  // The entry computation consists of a single three-element tuple parameter,
  // so that parameter is also the value the computation returns.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(F32, {}),
                                 ShapeUtil::MakeShape(S32, {42})}),
      "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be four allocations: one for vector of pointers, and one for
  // each tuple element.
  EXPECT_EQ(4, assignment->Allocations().size());

  // Verify each buffer allocation is marked as an entry computation parameter
  // and is liveout.
  ShapeUtil::ForEachSubshape(
      tuple_param->shape(),
      [this, &assignment, tuple_param](const Shape& /*subshape*/,
                                       const ShapeIndex& index) {
        // Bind by reference: the allocation is only inspected here, so there
        // is no reason to copy it once per subshape.
        const auto& allocation =
            GetAllocation(*assignment, tuple_param, index);
        EXPECT_TRUE(allocation.is_entry_computation_parameter());
        EXPECT_EQ(0, allocation.parameter_number());
        EXPECT_TRUE(allocation.maybe_live_out());
      });
}
1296 
TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
  // The entry computation returns GetTupleElement(param0, 1), where param0 is
  // a tuple whose element 1 is itself a two-element tuple.
  const Shape nested_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
                                  ShapeUtil::MakeShape(S32, {101})})});

  HloComputation::Builder builder(TestName());
  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, nested_param_shape, "param0"));
  HloInstruction* tuple_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Liveness of the parameter's pieces: the top-level tuple is not live out,
  // whereas element {1} and both of its leaves are, because the
  // GetTupleElement forwards a pointer to that allocation rather than
  // defining a buffer of its own.
  EXPECT_FALSE(
      GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
                  .maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
                  .maybe_live_out());

  // The result of the GetTupleElement itself is live out.
  EXPECT_TRUE(
      GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());

  // The leaves seen through the GetTupleElement alias the corresponding
  // leaves of the parameter, so their allocations must be equal.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
            GetAllocation(*assignment, tuple_element, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
            GetAllocation(*assignment, tuple_element, /*index=*/{1}));

  // Likewise the top-level allocation of the GetTupleElement is the same as
  // the allocation of the parameter element it forwards.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
            GetTopLevelAllocation(*assignment, tuple_element));
}
1344 
1345 // TODO(b/32248867): Enable when buffer assignment gives allocations to
1346 // constants.
TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
  // The entry computation holds nothing but a two-element tuple constant,
  // which is therefore forwarded directly to the computation output.
  HloComputation::Builder builder(TestName());
  Literal first_element = LiteralUtil::CreateR0<int64>(0);
  Literal second_element = LiteralUtil::CreateR0<int64>(1);
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({&first_element, &second_element})));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Three allocations are expected in total (presumably one for the tuple
  // itself plus one per element).
  EXPECT_EQ(3, assignment->Allocations().size());
}
1362 
TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
  // The entry computation returns the tuple-shaped result of an operand-less
  // custom call.
  const Shape result_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(S32, {101})});

  HloComputation::Builder builder(TestName());
  HloInstruction* custom_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          result_shape, /*operands=*/{},
          /*custom_call_target=*/"foo_function"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Three allocations are expected, and the top-level tuple as well as both
  // of its elements must be flagged as possibly live out.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
}
1382 
TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
  // The entry computation returns the result of a kCall whose callee wraps its
  // parameter in a one-element tuple.
  auto module = CreateNewVerifiedModule();
  const Shape& elem_shape = f32vec4_;
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});

  // Callee: tuple(sub_param).
  HloComputation::Builder callee_builder(TestName() + "_sub");
  HloInstruction* sub_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
  HloInstruction* sub_tuple =
      callee_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
  HloComputation* sub_computation =
      module->AddEmbeddedComputation(callee_builder.Build());

  // Entry: call(sub_computation, param).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "param"));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
  module->AddEntryComputation(entry_builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(2, assignment->Allocations().size());

  // The call's buffers share allocations with the callee: the top-level tuple
  // with the callee's tuple, and element {0} with the callee's parameter.
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
            GetAllocation(*assignment, sub_tuple, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
            GetAllocation(*assignment, sub_param, /*index=*/{}));

  // The entry parameter shares an allocation with the call operand (the
  // callee's parameter) but not with the result tuple.
  EXPECT_NE(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, sub_tuple));
  EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, sub_param));
}
1419 
TEST_F(BufferAssignmentTest,TupleChainedCallAsOutput)1420 TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
1421   // Test a chain of calls with tuple output. The chain looks like:
1422   // A: call(B, tuple(param))
1423   // B: call(C, param)
1424   // C: call(D, param)
1425   // D: param
1426   auto module = CreateNewVerifiedModule();
1427   auto elem_shape = f32vec4_;
1428   auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1429 
1430   auto d_builder = HloComputation::Builder(TestName() + "_d");
1431   auto d_param = d_builder.AddInstruction(
1432       HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
1433   auto d_computation = d_builder.Build();
1434 
1435   auto c_builder = HloComputation::Builder(TestName() + "_c");
1436   auto c_param = c_builder.AddInstruction(
1437       HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
1438   auto c_call = c_builder.AddInstruction(
1439       HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
1440   auto c_computation = c_builder.Build();
1441 
1442   auto b_builder = HloComputation::Builder(TestName() + "_b");
1443   auto b_param = b_builder.AddInstruction(
1444       HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
1445   auto b_call = b_builder.AddInstruction(
1446       HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
1447   auto b_computation = b_builder.Build();
1448 
1449   auto a_builder = HloComputation::Builder(TestName());
1450   auto a_param = a_builder.AddInstruction(
1451       HloInstruction::CreateParameter(0, elem_shape, "param"));
1452   auto a_tuple =
1453       a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
1454   auto a_call = a_builder.AddInstruction(
1455       HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
1456   auto a_computation = a_builder.Build();
1457 
1458   // Add the computations in an order that doesn't match the dependency
1459   // post-order, to shake out more possible bugs.
1460   module->AddEmbeddedComputation(std::move(d_computation));
1461   module->AddEmbeddedComputation(std::move(c_computation));
1462   module->AddEntryComputation(std::move(a_computation));
1463   module->AddEmbeddedComputation(std::move(b_computation));
1464 
1465   auto assignment = RunBufferAssignment(module.get());
1466 
1467   // Buffers for call are colocated with the sub-computations.
1468   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
1469             GetAllocation(*assignment, b_call, /*index=*/{}));
1470   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
1471             GetAllocation(*assignment, c_call, /*index=*/{}));
1472   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
1473             GetAllocation(*assignment, d_param, /*index=*/{}));
1474   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
1475             GetAllocation(*assignment, b_call, /*index=*/{0}));
1476   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
1477             GetAllocation(*assignment, c_call, /*index=*/{0}));
1478   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
1479             GetAllocation(*assignment, d_param, /*index=*/{0}));
1480 
1481   EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
1482   EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
1483   EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));
1484 
1485   EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}),
1486             GetAllocation(*assignment, c_param, /*index=*/{0}));
1487   EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}),
1488             GetAllocation(*assignment, d_param, /*index=*/{0}));
1489 }
1490 
TEST_F(BufferAssignmentTest, BitcastAsOutput) {
  // The entry computation returns a bitcast of its only parameter.
  const Shape param_shape = ShapeUtil::MakeShape(F32, {42});

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // A single allocation is expected, shared by the parameter and the bitcast.
  EXPECT_EQ(1, assignment->Allocations().size());
  EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, bitcast));
}
1508 
TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
  // Test a computation with an output that has an ambiguous points-to set.
  // This is constructed using a select among tuple shapes.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});

  auto tuple_param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  auto tuple_param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  // Parameter number 2 gets its own name; it previously reused "param1",
  // which is already taken by parameter number 1.
  auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto select = builder.AddInstruction(
      HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect,
                                    pred_param, tuple_param0, tuple_param1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Select shallow copies one of its operands so it defines its own top-level
  // buffer and receives its own allocation. The allocation is only inspected,
  // so bind it by reference instead of copying it.
  const auto& select_alloc = GetTopLevelAllocation(*assignment, select);
  EXPECT_EQ(1, select_alloc.assigned_buffers().size());
  EXPECT_EQ(select,
            select_alloc.assigned_buffers().begin()->first->instruction());

  // The buffer for the tuple element of the select is forwarded from one of
  // its operands, which cannot be determined statically. Therefore its slices
  // should include the slices of both of the elements in the parameters.
  auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
  EXPECT_EQ(2, element_slices.size());
  EXPECT_THAT(element_slices,
              UnorderedElementsAre(
                  assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
                      .ConsumeValueOrDie(),
                  assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
                      .ConsumeValueOrDie()));
}
1549 
1550 // TODO(b/34669761): Remove this test when buffers are allowed to share
1551 // allocations.
TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
  // Builds copy(get-tuple-element(tuple(param0), 0)) and checks that the copy
  // is not placed in the allocation that holds the tuple.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param0"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param}));
  HloInstruction* tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape, HloOpcode::kCopy, tuple_element));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // No buffer may be reused: three allocations are expected, and the copy and
  // the tuple must live in different ones.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
            GetTopLevelAllocation(*assignment, copy));
}
1574 
TEST_F(BufferAssignmentTest,OneTempAllocation)1575 TEST_F(BufferAssignmentTest, OneTempAllocation) {
1576   // Test a computation that requires multiple temp buffers, and ensure they
1577   // are combined into a single allocation.
1578   auto builder = HloComputation::Builder(TestName());
1579   Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
1580   Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
1581   Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
1582   Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
1583   Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});
1584 
1585   // There should be separate temp buffers for dot_ab and dot_bc.
1586   auto param_a = builder.AddInstruction(
1587       HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
1588   auto param_b = builder.AddInstruction(
1589       HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
1590   auto param_c = builder.AddInstruction(
1591       HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
1592   DotDimensionNumbers dot_dnums;
1593   dot_dnums.add_lhs_contracting_dimensions(1);
1594   dot_dnums.add_rhs_contracting_dimensions(0);
1595   PrecisionConfig precision_config;
1596   precision_config.mutable_operand_precision()->Resize(
1597       2, PrecisionConfig::DEFAULT);
1598   auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
1599       shape_2x4, param_a, param_b, dot_dnums, precision_config));
1600   auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
1601       shape_3x4, param_b, param_c, dot_dnums, precision_config));
1602   builder.AddInstruction(
1603       HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
1604 
1605   // Run buffer assignment with alignment=1.
1606   auto module = CreateNewVerifiedModule();
1607   module->AddEntryComputation(builder.Build());
1608   auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
1609 
1610   // There are 5 allocations: 3 parameters, 1 output, and 1 temp.
1611   EXPECT_EQ(5, assignment->Allocations().size());
1612 
1613   // Ensure the temp buffers for dot_ab and dot_bc share a single allocation,
1614   // and each occupies different slices of that allocation.
1615   BufferAllocation::Slice slice_ab =
1616       assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
1617   BufferAllocation::Slice slice_bc =
1618       assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
1619   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1620   EXPECT_NE(slice_ab, slice_bc);
1621   EXPECT_EQ(32, slice_ab.size());
1622   EXPECT_EQ(48, slice_bc.size());
1623   EXPECT_EQ(80, slice_ab.allocation()->size());
1624   EXPECT_EQ(80, slice_bc.allocation()->size());
1625 
1626   // Re-run buffer assignment with alignment=64.
1627   assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
1628   EXPECT_EQ(5, assignment->Allocations().size());
1629   slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
1630   slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
1631   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1632   EXPECT_NE(slice_ab, slice_bc);
1633   EXPECT_EQ(32, slice_ab.size());
1634   EXPECT_EQ(48, slice_bc.size());
1635   // Ensure the offsets and allocation size account for the alignment, without
1636   // assuming which buffer gets assigned first.
1637   if (slice_ab.offset() == 0) {
1638     EXPECT_EQ(64, slice_bc.offset());
1639     EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
1640     EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
1641   } else {
1642     EXPECT_EQ(64, slice_ab.offset());
1643     EXPECT_EQ(0, slice_bc.offset());
1644     EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
1645     EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
1646   }
1647 }
1648 
TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
  // Builds the chain
  //   sub(add(mul(broadcast(paramscalar), param0), param1), param1)
  // and inspects the peak-memory buffers of the allocation that holds `mul`.
  //
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Exactly one logical buffer is expected at peak, and it must be the one
  // defined by the broadcast.
  const BufferAllocation& mul_allocation =
      GetTopLevelAllocation(*assignment, mul);
  const std::vector<const LogicalBuffer*>& peak =
      mul_allocation.PeakMemoryLogicalBuffers();
  ASSERT_EQ(peak.size(), 1);
  EXPECT_EQ(peak[0]->instruction(), broadcast);
}
1681 
TEST_F(BufferAssignmentTest, PeakBuffers) {
  // Computes the peak-liveness buffers for this instruction sequence:
  //
  //   %param = ...
  //   %log = log(%param)
  //   %rev = reverse(%log)
  //   %neg = neg(%param)
  //   %concat = concat(%rev, %neg)
  //   ROOT %root = slice(concat)
  //
  // Within the temporary block, the buffers live at peak memory use should be
  // {%rev, %neg, %concat}; that point is reached at the concat itself.
  const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
  // The root is kept tiny so that no interior node can share its buffer.
  const Shape root_shape = ShapeUtil::MakeShape(F32, {1});

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
  HloInstruction* rev = builder.AddInstruction(
      HloInstruction::CreateReverse(f32vec100_, log, {0}));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  HloInstruction* concat = builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
  HloInstruction* root = builder.AddInstruction(
      HloInstruction::CreateSlice(root_shape, concat, {0}, {1}, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignmentWithInstructionSequence(
      module.get(), {param, log, rev, neg, concat, root});

  // The allocation holding concat must be a preallocated temp buffer (neither
  // input nor output) containing the 4 interior instructions.
  const BufferAllocation& temp_allocation =
      GetTopLevelAllocation(*assignment, concat);
  EXPECT_FALSE(temp_allocation.IsInputOrOutput());
  EXPECT_TRUE(temp_allocation.IsPreallocatedTempBuffer());
  ASSERT_EQ(temp_allocation.assigned_buffers().size(), 4);

  // At peak, the live set should be the concat together with its two inputs.
  const std::vector<const LogicalBuffer*>& peak =
      temp_allocation.PeakMemoryLogicalBuffers();
  ASSERT_EQ(peak.size(), 3);
  std::vector<const HloInstruction*> peak_instructions;
  for (const LogicalBuffer* live_buffer : peak) {
    peak_instructions.push_back(live_buffer->instruction());
  }
  EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
}
1734 
TEST_F(BufferAssignmentTest,PeakBuffersWhile)1735 TEST_F(BufferAssignmentTest, PeakBuffersWhile) {
1736   auto module = CreateNewVerifiedModule();
1737   const Shape shape = ShapeUtil::MakeShape(F32, {123, 123});
1738   HloComputation* condition;
1739   {
1740     auto b = HloComputation::Builder(TestName() + ".cond");
1741     b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
1742     b.AddInstruction(
1743         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1744     condition = module->AddEmbeddedComputation(b.Build());
1745   }
1746   HloComputation* body;
1747   {
1748     auto b = HloComputation::Builder(TestName() + ".body");
1749     auto param =
1750         b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
1751     b.AddInstruction(
1752         HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
1753     body = module->AddEmbeddedComputation(b.Build());
1754   }
1755   auto builder = HloComputation::Builder(TestName());
1756   auto param =
1757       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
1758   auto copy = builder.AddInstruction(
1759       HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param));
1760   auto while_op = builder.AddInstruction(
1761       HloInstruction::CreateWhile(shape, condition, body, copy));
1762   // This broadcast should get a temporary allocation which is merged with the
1763   // allocation for the while. Peak buffers should include the while and the
1764   // broadcast.
1765   auto bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1766       ShapeUtil::MakeShape(F32, {123, 123, 123}), while_op, {0, 1}));
1767   builder.AddInstruction(HloInstruction::CreateReverse(
1768       ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0}));
1769   module->AddEntryComputation(builder.Build());
1770 
1771   auto buffers = RunBufferAssignment(module.get());
1772   const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast);
1773   const std::vector<const LogicalBuffer*>& peak_buffers =
1774       buffer.PeakMemoryLogicalBuffers();
1775   ASSERT_EQ(peak_buffers.size(), 2);
1776 
1777   // The peak buffers should include the broadcast and one of the colocated
1778   // buffers of the while (body param, condition param, body root, or the while
1779   // itself).
1780   const LogicalBuffer* bcast_buffer;
1781   const LogicalBuffer* nonbcast_buffer;
1782   if (peak_buffers[0]->instruction() == bcast) {
1783     bcast_buffer = peak_buffers[0];
1784     nonbcast_buffer = peak_buffers[1];
1785   } else {
1786     bcast_buffer = peak_buffers[1];
1787     nonbcast_buffer = peak_buffers[0];
1788   }
1789   EXPECT_EQ(bcast_buffer->instruction(), bcast);
1790   EXPECT_TRUE(
1791       nonbcast_buffer->instruction() == copy ||
1792       nonbcast_buffer->instruction() == while_op ||
1793       nonbcast_buffer->instruction() == body->parameter_instruction(0) ||
1794       nonbcast_buffer->instruction() == body->root_instruction() ||
1795       nonbcast_buffer->instruction() == condition->parameter_instruction(0));
1796 }
1797 
TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
  // Parses a module whose entry feeds a copied constant and a raw constant
  // into a conditional, then checks that neither constant's allocation also
  // holds a buffer defined by a copy or by the conditional.
  const char* hlo_text = R"(
HloModule Module

True {
  ROOT x.0.1 = f32[] parameter(0)
}

False {
  x.0.0 = f32[] parameter(0)
  ROOT copy.1 = f32[] copy(x.0.0)
}

ENTRY main {
  pred.1.0 = pred[] parameter(0)
  constant.1.1 = f32[] constant(56)
  copy.2 = f32[] copy(constant.1.1)
  constant.1.2 = f32[] constant(12)
  ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2),
      true_computation=True, false_computation=False
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* constant_1 =
      m->entry_computation()->GetInstructionWithName("constant.1.1");
  HloInstruction* constant_2 =
      m->entry_computation()->GetInstructionWithName("constant.1.2");
  // Fail cleanly, rather than dereferencing a null pointer below, if either
  // lookup ever stops matching the HLO text above.
  ASSERT_NE(constant_1, nullptr);
  ASSERT_NE(constant_2, nullptr);

  auto buffers = RunBufferAssignment(m.get());

  // Both constants are subject to identical checks, so run them in one loop.
  for (HloInstruction* constant : {constant_1, constant_2}) {
    // Identify which constant a failing expectation belongs to.
    SCOPED_TRACE(constant->name());
    const BufferAllocation& constant_allocation =
        GetTopLevelAllocation(*buffers, constant);
    EXPECT_TRUE(constant_allocation.is_constant());
    for (const auto& buffer_offset_pair :
         constant_allocation.assigned_buffers()) {
      EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
                HloOpcode::kCopy);
      EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
                HloOpcode::kConditional);
    }
  }
}
1855 
1856 class WhileBufferAssignmentTest : public HloTestBase {
1857  protected:
BuildWhileConditionComputation(const string & name)1858   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
1859       const string& name) {
1860     auto builder = HloComputation::Builder(name);
1861     builder.AddInstruction(
1862         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
1863     auto zero = builder.AddInstruction(
1864         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
1865     auto ten = builder.AddInstruction(
1866         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
1867     builder.AddInstruction(HloInstruction::CreateCompare(
1868         ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
1869     return builder.Build();
1870   }
1871 
BuildWhileBodyComputation(const string & name)1872   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
1873       const string& name) {
1874     auto builder = HloComputation::Builder(name);
1875     auto loop_state = builder.AddInstruction(
1876         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
1877     auto input = builder.AddInstruction(
1878         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
1879     auto weights = builder.AddInstruction(
1880         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
1881     auto output = builder.AddInstruction(HloInstruction::CreateBinary(
1882         data_shape_, HloOpcode::kMultiply, input, weights));
1883     builder.AddInstruction(
1884         HloInstruction::CreateTuple({input, weights, output}));
1885     return builder.Build();
1886   }
1887 
RunBufferAssignment(HloModule * module,int64 alignment=1)1888   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
1889                                                         int64 alignment = 1) {
1890     HloSchedule schedule =
1891         ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
1892     return BufferAssigner::Run(
1893                module, absl::make_unique<SequentialHloOrdering>(schedule),
1894                ByteSizeOf,
1895                [alignment](LogicalBuffer::Color) { return alignment; },
1896                /*allow_input_output_aliasing=*/false,
1897                /*allocate_buffers_for_constants=*/true)
1898         .ConsumeValueOrDie();
1899   }
1900 
ByteSizeOf(const BufferValue & buffer)1901   static int64 ByteSizeOf(const BufferValue& buffer) {
1902     return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
1903   }
1904 
1905   Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
1906   Shape loop_state_shape_ =
1907       ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
1908 };
1909 
RunCopyInsertion(HloModule * module)1910 static void RunCopyInsertion(HloModule* module) {
1911   CopyInsertion copy_insertion;
1912   EXPECT_IS_OK(copy_insertion.Run(module).status());
1913 }
1914 
TEST_F(WhileBufferAssignmentTest,TwoForwardWhileLoops)1915 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
1916   auto module = CreateNewVerifiedModule();
1917   auto builder = HloComputation::Builder("entry");
1918 
1919   auto input0 = builder.AddInstruction(
1920       HloInstruction::CreateParameter(0, data_shape_, "input0"));
1921   auto weights0 = builder.AddInstruction(
1922       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
1923   auto weights1 = builder.AddInstruction(
1924       HloInstruction::CreateParameter(2, data_shape_, "weights1"));
1925 
1926   auto zero = builder.AddInstruction(
1927       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
1928   auto output0 = builder.AddInstruction(
1929       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
1930   auto output1 = builder.AddInstruction(
1931       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
1932 
1933   auto cond0 =
1934       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
1935   auto body0 =
1936       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
1937 
1938   auto tuple0 = builder.AddInstruction(
1939       HloInstruction::CreateTuple({input0, weights0, output0}));
1940   auto while0 = builder.AddInstruction(
1941       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
1942 
1943   auto cond1 =
1944       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
1945   auto body1 =
1946       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
1947   auto input1 = builder.AddInstruction(
1948       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
1949   auto tuple1 = builder.AddInstruction(
1950       HloInstruction::CreateTuple({input1, weights1, output1}));
1951   auto while1 = builder.AddInstruction(
1952       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
1953 
1954   module->AddEntryComputation(builder.Build());
1955   RunCopyInsertion(module.get());
1956   auto assignment = RunBufferAssignment(module.get());
1957 
1958   // Verify 'input0' and read-only use while0{0} alias.
1959   EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(),
1960             assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie());
1961   // Verify 'weights0' and read-only use while0{1} alias.
1962   EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).ConsumeValueOrDie(),
1963             assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie());
1964   // Verify 'while0{2}' and read-only use while1{0} alias.
1965   EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
1966             assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
1967   // Verify 'weights1' and read-only use while1{1} alias.
1968   EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).ConsumeValueOrDie(),
1969             assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
1970 }
1971 
1972 // Tests that two colocated buffer sets are not merged if an entry parameter
1973 // buffer belongs to either of the colocation sets (b/73267882).
1974 //
1975 // %param --> %while.0 --> %mul --> %while.1 --> %broadcast
1976 //
// %while.0 body just forwards the init value, so its loop-carried variable
// never changes, whereas %while.1 modifies its loop-carried variable.
TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) {
  // Entry parameter -> %while.0 (identity body) -> %mul -> %while.1 (body that
  // adds the value to itself) -> %bcast.
  const char* module_str = R"(
HloModule test_module

%cond.v0 {
  %param = s32[] parameter(0)
  ROOT %constant = pred[] constant(true)
}

%cond.v1 {
  %param.0 = s32[] parameter(0)
  ROOT %constant.0 = pred[] constant(true)
}

%body.v0 {
  ROOT %param.1 = s32[] parameter(0)
}

%body.v1 {
  %param.2 = s32[] parameter(0)
  ROOT add = s32[] add(%param.2, %param.2)
}

ENTRY %test_module {
  %param.3 = s32[] parameter(0)
  %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0
  %mul = s32[] multiply(%while.0, %while.0)
  %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  // Run CopyInsertion and check that the graph constructed above doesn't need
  // any copies inserted for BufferAssignment to run.
  int64 instruction_count = m->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
  ASSERT_EQ(instruction_count, m->instruction_count());

  // Recover the instructions of interest by walking back from the root:
  // %bcast -> %while.1 -> %mul -> %while.0. Each hop is asserted so that a
  // change in the graph shape fails here rather than in a later lookup.
  const HloInstruction* bcast = m->entry_computation()->root_instruction();
  const HloInstruction* param =
      m->entry_computation()->parameter_instruction(0);
  ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
  const HloInstruction* while1 = bcast->operand(0);
  ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
  const HloInstruction* mul = while1->operand(0);
  ASSERT_EQ(mul->opcode(), HloOpcode::kMultiply);
  const HloInstruction* while0 = mul->operand(0);
  ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);

  // Run buffer assignment.
  auto assignment = RunBufferAssignment(m.get());
  TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
                          assignment->GetUniqueSlice(param, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));

  // The parameter slice is part of the while0's colocation set (init value),
  // but not merged into the while1's colocation set.
  EXPECT_EQ(slice_param, slice_while0);
  EXPECT_NE(slice_param, slice_while1);
}
2045 
// Same graph as ColocatedBufferWithEntryParameter, but the init value of
// %while.0 is a constant instead of an entry parameter.
TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
  const char* module_str = R"(
HloModule test_module

%cond.v0 {
  %param = s32[] parameter(0)
  ROOT %constant = pred[] constant(true)
}

%cond.v1 {
  %param.0 = s32[] parameter(0)
  ROOT %constant.0 = pred[] constant(true)
}

%body.v0 {
  ROOT %param.1 = s32[] parameter(0)
}

%body.v1 {
  %param.2 = s32[] parameter(0)
  ROOT add = s32[] add(%param.2, %param.2)
}

ENTRY %test_module {
  %constant.42 = s32[] constant(42)
  %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
  %mul = s32[] multiply(%while.0, %while.0)
  %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  // Run CopyInsertion and check that the graph constructed above doesn't need
  // any copies inserted for BufferAssignment to run.
  int64 instruction_count = m->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
  ASSERT_EQ(instruction_count, m->instruction_count());

  // Get the instructions in the module.
  const HloInstruction* bcast = m->entry_computation()->root_instruction();
  const HloInstruction* constant =
      m->entry_computation()->GetInstructionWithName("constant.42");
  // Fail cleanly if the lookup by name did not find the constant, instead of
  // handing a null pointer to the slice lookup below.
  ASSERT_NE(constant, nullptr);
  ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
  const HloInstruction* while1 = bcast->operand(0);
  ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
  // %while.1's operand is %mul, whose operand is %while.0.
  const HloInstruction* while0 = while1->operand(0)->operand(0);
  ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);

  // Run buffer assignment.
  auto assignment = RunBufferAssignment(m.get());
  TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
                          assignment->GetUniqueSlice(constant, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));

  // The constant slice is part of the while0's colocation set (init value), but
  // not merged into the while1's colocation set.
  EXPECT_EQ(slice_constant, slice_while0);
  EXPECT_NE(slice_constant, slice_while1);
}
2112 
2113 // Tests that the colocated buffers for while instructions are properly assigned
2114 // during buffer assignment such that the result tuple elements are not assigned
2115 // to the same buffer.
2116 //
2117 // %infeed --> %while.0 --> %while.1 --+
2118 //                                     +-- %tuple
2119 //   %zero -->   %add   --> %while.2 --+
2120 //
2121 // Execution Order:
2122 // %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
2123 //
// The HLO computation used in this test requires a specific ordering to expose
// the bug (b/72496031). During buffer assignment, the visitation order of
// colocated buffers is %while.2 -> %while.0 -> %while.1, and the buffer
// assignment was coalescing the colocated buffers for all 3 while instructions,
// therefore assigning the same buffer to the two result tuple elements.
TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
  // All loop-carried values in this test are scalar s32.
  const Shape r0s32 = ShapeUtil::MakeShape(S32, {});

  // Builds a condition computation: x -> x < 4
  auto build_cond = [&]() {
    auto builder = HloComputation::Builder("cond");
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
    builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
                                      const4, ComparisonDirection::kLt));
    return builder.Build();
  };

  // Builds a body computation: x -> x + 9
  auto build_body = [&]() {
    auto builder = HloComputation::Builder("body");
    auto const9 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
    return builder.Build();
  };

  // Build the entry computation as described in the comment above.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder("entry");

  // First chain: element 0 of the infeed result feeds while0, whose result
  // feeds while1. Each while gets its own freshly built cond/body pair.
  auto token = builder.AddInstruction(HloInstruction::CreateToken());
  auto infeed =
      builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
  auto infeed_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
  auto cond0 = module->AddEmbeddedComputation(build_cond());
  auto body0 = module->AddEmbeddedComputation(build_body());
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));

  auto cond1 = module->AddEmbeddedComputation(build_cond());
  auto body1 = module->AddEmbeddedComputation(build_body());
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond1, body1, while0));

  // Second chain: zero + zero feeds while2.
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
  auto cond2 = module->AddEmbeddedComputation(build_cond());
  auto body2 = module->AddEmbeddedComputation(build_body());
  auto while2 = builder.AddInstruction(
      HloInstruction::CreateWhile(r0s32, cond2, body2, add));

  // Result tuple: element 0 is while2, element 1 is while1.
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
  module->AddEntryComputation(builder.Build());

  // Run CopyInsertion and check if the graph constructed above doesn't need
  // any copies inserted for BufferAssignment to run.
  int64 instruction_count = module->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  ASSERT_EQ(instruction_count, module->instruction_count());

  // Create a sequential order among all the instructions in the entry
  // computation, since the issue this test stresses depends on the order the
  // nodes are traversed during BufferAssignment.
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(),
                                     /*pointer_size=*/sizeof(void*));
      }));
  // Overwrite the entry computation's sequence with the execution order given
  // in the comment above this test; the other computations keep the sequence
  // produced by ScheduleModule.
  schedule.set_sequence(
      module->entry_computation(),
      {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
  TF_ASSERT_OK(schedule.Verify());

  // NOTE(review): the size function here comes from the backend compiler,
  // whereas scheduling above used a sizeof(void*)-based lambda.
  TF_ASSERT_OK_AND_ASSIGN(
      auto assignment,
      BufferAssigner::Run(
          module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
          backend().compiler()->BufferSizeBytesFunction(),
          [](LogicalBuffer::Color) { return 1; },
          /*allow_input_output_aliasing=*/false,
          /*allocate_buffers_for_constants=*/true));

  // The result tuple elements must be assigned with different buffers.
  TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
  EXPECT_NE(slice0, slice1);

  // while0 and while1 result buffers must be equal to slice1.
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));
  EXPECT_EQ(slice1, slice_while0);
  EXPECT_EQ(slice1, slice_while1);

  // while2 result buffer must be equal to slice0.
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
                          assignment->GetUniqueSlice(while2, {}));
  EXPECT_EQ(slice0, slice_while2);
}
2237 
// Feeds the whole result tuple of one while loop directly into a second while
// loop and checks that every tuple element of both loops is assigned the same
// slice.
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry("entry");

  HloInstruction* input0 = entry.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  HloInstruction* weights0 = entry.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  // Zero broadcast used as the initial value of tuple element 2.
  HloInstruction* zero = entry.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* output0 = entry.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // First loop, initialized with (input0, weights0, output0).
  HloComputation* cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  HloComputation* body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  HloInstruction* tuple0 = entry.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  HloInstruction* while0 = entry.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));

  // Second loop, initialized with the complete result of the first loop.
  HloComputation* cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  HloComputation* body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  HloInstruction* while1 = entry.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));

  module->AddEntryComputation(entry.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Looks up the unique slice assigned to `hlo` at `index`.
  const auto slice_at = [&assignment](const HloInstruction* hlo,
                                      const ShapeIndex& index) {
    return assignment->GetUniqueSlice(hlo, index).ConsumeValueOrDie();
  };

  // while0 and while1 buffers should be completely aligned: each of the three
  // tuple elements must live in the same slice for both loops.
  for (int64 element = 0; element < 3; ++element) {
    EXPECT_EQ(slice_at(while0, {element}), slice_at(while1, {element}));
  }
}
2282 
// Calls the same sub-computation (param + 1.0) from two call sites in the
// entry computation and checks that the two call results get distinct
// buffers.
TEST_F(BufferAssignmentTest, TwoCalls) {
  auto module = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
  HloComputation* sub_computation = nullptr;
  {
    // Sub-computation: param -> param + 1.0
    auto builder = HloComputation::Builder(TestName() + "_sub_comp");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32, "param"));
    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
    sub_computation = module->AddEmbeddedComputation(builder.Build(add));
  }

  // Entry: add2 = call(3.0) + (call(2.0) + 2.0)
  auto builder = HloComputation::Builder(TestName());
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  auto call1 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
  auto call2 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
  auto add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
  module->AddEntryComputation(builder.Build(add2));

  {
    // The sub-computation has two call sites, so flattening is expected to
    // report that it changed the module.
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
    EXPECT_TRUE(result);
  }

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}
2324 
// Checks that the parameters of a called computation share allocations with
// the corresponding call operands, including nested tuple elements.
TEST_F(BufferAssignmentTest, CallParamCoAllocation) {
  const char* hlo_text = R"(
HloModule CallParamCoAllocation

Callee {
  param0 = (f32[100],(f32[200],f32[300])) parameter(0)
  param1 = s32[20] parameter(1)
  ROOT constant = f32[] constant(1)
}

ENTRY Main {
  entry_param0 = f32[100] parameter(0)
  entry_param1 = s32[20]  parameter(1)
  custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target"
  call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call)
  ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee
}
)";

  HloModuleConfig config;
  config.set_debug_options(GetDebugOptionsFromFlags());
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(hlo_text, config));

  auto buffers = RunBufferAssignment(m.get());

  HloComputation* main = m->entry_computation();
  HloComputation* callee = m->GetComputationWithName("Callee");
  // Must be fatal: `callee` is dereferenced immediately below.
  ASSERT_NE(callee, nullptr);

  HloInstruction* param0 = callee->parameter_instruction(0);
  HloInstruction* param1 = callee->parameter_instruction(1);

  HloInstruction* entry_param0 = main->parameter_instruction(0);
  HloInstruction* entry_param1 = main->parameter_instruction(1);
  HloInstruction* custom_call = main->GetInstructionWithName("custom_call");
  // Fail cleanly if the lookup by name did not find the instruction.
  ASSERT_NE(custom_call, nullptr);

  // Callee parameters vs. the entry parameters that feed them.
  EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}),
            GetAllocation(*buffers, param0, {0}));
  EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}),
            GetAllocation(*buffers, param1, {}));

  // The custom-call result is element 1 of the tuple passed as param0; its
  // top-level buffer and both of its elements must line up with param0{1}.
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {}),
            GetAllocation(*buffers, param0, {1}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}),
            GetAllocation(*buffers, param0, {1, 0}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}),
            GetAllocation(*buffers, param0, {1, 1}));
}
2374 
// Builds two independent while loops whose results are both consumed by the
// root add, schedules the entry computation with a hand-written sequence, and
// checks that the two loops are assigned distinct buffers (b/38494731).
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Inputs of the first loop: two parameters and a broadcast of `zero`.
  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // Inputs of the second loop: two parameters and a broadcast of `one`.
  auto input1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, data_shape_, "input1"));
  auto weights1 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, data_shape_, "weights1"));
  auto output1 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, one, {}));

  // Both loops initially reference the same condition and body computations.
  auto cond =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({input1, weights1, output1}));

  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));

  // The root adds element 0 of while0 to element 1 of while1, so both loop
  // results are used by the root.
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
  auto root_add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));

  module->AddEntryComputation(builder.Build());

  {
    // `cond` and `body` are shared by both whiles, so flattening is expected
    // to report a change.
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
    EXPECT_TRUE(result);
  }

  RunCopyInsertion(module.get());

  HloSchedule schedule =
      ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();

  // To trigger b/38494731, we want a specific Hlo schedule for the
  // root computation, so we overwrite that entry with a manually
  // crafted sequence.
  // The sequence runs while1 and everything it needs first, then while0 and
  // its inputs, then the consumers. The current operand 0 of each while is
  // used rather than tuple0/tuple1 (presumably because an earlier pass may
  // have replaced the original init tuple; confirm).
  schedule.set_sequence(
      module->entry_computation(),
      {input1, weights1, one, output1, while1->mutable_operand(0), while1,
       input0, weights0, zero, output0, while0->mutable_operand(0), while0,
       gte0, gte1, root_add});

  // If this ASSERT fails, we constructed a bogus sequence above and this test
  // itself is buggy.
  TF_ASSERT_OK(schedule.Verify());

  auto assignment =
      BufferAssigner::Run(
          module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
          ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
          /*allow_input_output_aliasing=*/false,
          /*allocate_buffers_for_constants=*/true)
          .ConsumeValueOrDie();

  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
2455 
TEST_F(WhileBufferAssignmentTest,WhilesDontShareEntryParamIfLiveOut)2456 TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
2457   auto module = CreateNewVerifiedModule();
2458   auto builder = HloComputation::Builder("entry");
2459 
2460   auto input0 = builder.AddInstruction(
2461       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2462   auto weights0 = builder.AddInstruction(
2463       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2464 
2465   auto zero = builder.AddInstruction(
2466       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2467   auto output0 = builder.AddInstruction(
2468       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2469   auto output1 = builder.AddInstruction(
2470       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2471 
2472   auto cond0 =
2473       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2474   auto body0 =
2475       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2476 
2477   auto tuple0 = builder.AddInstruction(
2478       HloInstruction::CreateTuple({input0, weights0, output0}));
2479   auto while0 = builder.AddInstruction(
2480       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2481 
2482   // Get output of 'while0' and feed as input to 'while1'.
2483   auto while0_out = builder.AddInstruction(
2484       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2485 
2486   auto cond1 =
2487       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2488   auto body1 =
2489       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2490 
2491   auto tuple1 = builder.AddInstruction(
2492       HloInstruction::CreateTuple({while0_out, weights0, output1}));
2493   auto while1 = builder.AddInstruction(
2494       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2495 
2496   // Get output of 'while1' so that it is live out of computation.
2497   auto while1_out = builder.AddInstruction(
2498       HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
2499 
2500   module->AddEntryComputation(builder.Build());
2501   RunCopyInsertion(module.get());
2502   auto assignment = RunBufferAssignment(module.get());
2503   // Get BufferAllocation for root instruction.
2504   auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out)
2505                          .ConsumeValueOrDie()
2506                          .allocation();
2507   // Test that root instruction allocation is live out.
2508   EXPECT_TRUE(root_alloc->maybe_live_out());
2509   // Test that root instruction allocation is not an entry parameter.
2510   EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
2511 }
2512 
// Checks that two chained dynamic-update-slice instructions inside a while
// body are assigned the same allocation slice.
TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) {
  // The declared shape of the body root `tuple.85` matches its two operands
  // (add.5, dynamic-update-slice.9) and the loop state shape.
  const char* const hlo_string = R"(
HloModule test

while_body {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}

while_condition {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  get-tuple-element = s32[] get-tuple-element(state), index=0
  get-tuple-element.1 = s32[] constant(3)
  ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
}

ENTRY entry_computation {
  constant.7 = s32[] constant(0)
  copy.1 = s32[] copy(constant.7)
  constant.6 = f32[] constant(0)
  broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={}
  tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6)
  while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body
  ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0
}

)";
  // Report a parse failure as a test failure instead of dying.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Look up both dynamic-update-slice instructions of the while body and fail
  // cleanly if either name is not found.
  auto dus9 = FindInstruction(module.get(), "dynamic-update-slice.9");
  ASSERT_NE(dus9, nullptr);
  auto dus5 = FindInstruction(module.get(), "dynamic-update-slice.5");
  ASSERT_NE(dus5, nullptr);

  TF_ASSERT_OK_AND_ASSIGN(auto dus9_alloc_slice,
                          assignment->GetUniqueTopLevelSlice(dus9));
  TF_ASSERT_OK_AND_ASSIGN(auto dus5_alloc_slice,
                          assignment->GetUniqueTopLevelSlice(dus5));

  // Test that the two dynamic-update-slice ops share the same allocation slice.
  EXPECT_EQ(dus9_alloc_slice.allocation(), dus5_alloc_slice.allocation());
  EXPECT_EQ(dus9_alloc_slice, dus5_alloc_slice);
}
2566 }  // namespace
2567 }  // namespace xla
2568