1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/buffer_value.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
30 #include "tensorflow/compiler/xla/service/hlo_value.h"
31 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 
36 namespace xla {
37 namespace {
38 
39 class MinimumMemoryForSequenceTest : public HloTestBase {};
40 
// Checks HeapSimulator::MinimumMemoryForModule on a module whose entry
// computation contains a while loop, so the reported peak must include the
// memory used inside the condition/body subcomputations on top of the entry
// computation's live buffers. The byte counts in the comments assume the
// size_fn below (8-byte pointers for tuple index tables, 4-byte F32).
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  // Condition computation: param[0] < param[1].
  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Body computation: just its parameter, i.e. it returns the loop state
  // unchanged.
  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry computation: while(tuple(param_iter, param_data)).
  auto builder = HloComputation::Builder(TestName());
  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
  // While: 8 bytes (4 bytes per element), TOTAL=32
  // Both cond and body use a max of 24 bytes, TOTAL=56
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  // Buffer sizes with an explicit 8-byte pointer size, which the tuple byte
  // counts above rely on.
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  // An explicit instruction order for every computation; Verify() must accept
  // it before the simulator consumes it.
  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_data, cond_lt});
  schedule.set_sequence(body_computation, {body_param});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_ASSERT_OK(schedule.Verify());

  // 56 is the final running TOTAL from the comments above.
  EXPECT_EQ(
      56,
      HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}
100 
// Checks HeapSimulator::MinimumMemoryForComputation on an entry computation
// containing a while loop. The per-subcomputation memory is passed in through
// memory_by_computation, and the result must not double count the while's
// aliased output buffer.
TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
  // HloModule SubcomputationAccounting

  // %WhileBody (body_param: f32[4]) -> f32[4] {
  //   %body_param = f32[4]{0} parameter(0)
  //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
  //   ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param,
  //                                       f32[4]{0} %constant.1)
  // }

  // %WhileCond (cond_param: f32[4]) -> pred[] {
  //   %cond_param = f32[4]{0} parameter(0)
  //   %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
  //   %reshape = f32[] reshape(f32[1]{0} %slice)
  //   %constant = f32[] constant(0)
  //   ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
  //                        direction=NE
  // }

  // ENTRY %SubcomputationAccounting () -> f32[2,4] {
  //   %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 },
  //                                                  { 1, 2, 3, 4 } })
  //   %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} %constant.3),
  //                dimensions={0,1}
  //   %constant.2 = f32[4]{0} constant({1, 1, 1, 1})
  //   %while = f32[4]{0} while(f32[4]{0} %constant.2),
  //            condition=%WhileCond, body=%WhileBody
  //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={1}
  //   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %transpose,
  //                                 f32[2,4]{1,0} %broadcast)
  // }

  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});

  // reshape(slice(param)) != 0
  // Needs 5 bytes
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* slice =
      cond_builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
  HloInstruction* reshape =
      cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
  HloInstruction* zero = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  HloInstruction* cond_comparison = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
                                    zero, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

  // param - 1
  // Needs 16 bytes
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  HloInstruction* subtract =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  // transpose(matrix) + bcast(while)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  // Creates 16 bytes, ignoring subcomputations
  HloInstruction* while_loop =
      builder.AddInstruction(HloInstruction::CreateWhile(
          r1f32, cond_computation, body_computation, while_init));

  // Creates 32 bytes and frees 16
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
  // Creates 32 bytes
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));

  // Creates 32 bytes and frees 64
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));

  auto entry_computation = module->AddEntryComputation(builder.Build());

  // Explicit instruction order for each of the three computations.
  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero,
                                           cond_comparison};
  std::vector<HloInstruction*> while_body_vec = {body_param, one_vector,
                                                 subtract};
  std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast,
                                                 matrix,     transpose,  add};
  schedule.set_sequence(cond_computation, cond_vec);
  schedule.set_sequence(body_computation, while_body_vec);
  schedule.set_sequence(entry_computation, entry_comp_vec);

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  // Pre-supplied memory figures for the subcomputations; they match the
  // "Needs N bytes" comments on each subcomputation above.
  absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
  memory_by_computation[cond_computation] = 5;
  memory_by_computation[body_computation] = 16;
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();

  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
  // so we don't double count.
  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *points_to_analysis, size_fn, &memory_by_computation)
                    .ValueOrDie());
}
219 
220 const char kAlloc[] = "Alloc";
221 const char kFree[] = "Free";
222 const char kFinish[] = "Finish";
223 
224 // CallSequence records a sequence of Alloc/Free/Finish calls.
225 using CallSequence = std::vector<std::pair<string, const BufferValue*>>;
226 
227 // HeapCallRecorder is a dummy heap algorithm that simply records its calls.
228 class HeapCallRecorder : public HeapAlgorithm {
229  public:
HeapCallRecorder(CallSequence * calls)230   explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
~HeapCallRecorder()231   ~HeapCallRecorder() override {}
232 
Alloc(const BufferValue * buffer,int64 size)233   void Alloc(const BufferValue* buffer, int64 size) override {
234     calls_->emplace_back(kAlloc, buffer);
235     // Instead of assigning a real offset, we set the cardinality of the Alloc
236     // call.  This isn't a valid assignment, but allows us to easily test for
237     // buffer sharing.
238     const int64 offset = result_.chunk_map.size();
239     result_.chunk_map.emplace(buffer, Chunk{offset, size});
240   }
Free(const BufferValue * buffer,int64 size)241   void Free(const BufferValue* buffer, int64 size) override {
242     calls_->emplace_back(kFree, buffer);
243   }
Finish()244   Result Finish() override {
245     calls_->emplace_back(kFinish, nullptr);
246     return result_;
247   }
248 
249  private:
250   CallSequence* calls_;
251   Result result_;
252 };
253 
254 // HeapSimulatorTracker runs the heap simulator, recording the sequence of calls
255 // made to the underlying heap algorithm.  Tests compare the actual call
256 // sequence against an expected sequence.
257 class HeapSimulatorTracker {
258  public:
259   // Constructor for testing a single entry computation.
HeapSimulatorTracker(const string & name,std::unique_ptr<HloComputation> computation,const std::vector<HloInstruction * > & instruction_sequence)260   HeapSimulatorTracker(
261       const string& name, std::unique_ptr<HloComputation> computation,
262       const std::vector<HloInstruction*>& instruction_sequence) {
263     HloModuleConfig config;
264     module_ = absl::make_unique<HloModule>(name, config);
265     module_->AddEntryComputation(std::move(computation));
266     points_to_analysis_ =
267         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
268     // Since we're only tracking the sequence of Alloc/Free calls, the actual
269     // size of the buffers doesn't matter, so we always return 0.  We rely on
270     // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by
271     // buffer id, for determinism in the tests.
272     auto zero_size = [](const BufferValue& buffer) { return 0; };
273     auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
274         absl::make_unique<HeapCallRecorder>(&actual_calls_));
275     result_ =
276         HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
277                            HloInstructionSequence(instruction_sequence),
278                            *points_to_analysis_, zero_size)
279             .ConsumeValueOrDie();
280   }
281 
HeapSimulatorTracker(const string & name)282   explicit HeapSimulatorTracker(const string& name) {
283     HloModuleConfig config;
284     module_ = absl::make_unique<HloModule>(name, config);
285   }
286 
287   // Similar to the single entry computation constructor above, but runs the
288   // simulation over the entire module.
RunWholeModule(const std::vector<HloInstruction * > & full_module_sequence)289   void RunWholeModule(
290       const std::vector<HloInstruction*>& full_module_sequence) {
291     points_to_analysis_ =
292         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
293 
294     // Construct the module sequence grouped by computation.
295     HloSchedule schedule(module_.get());
296     absl::flat_hash_map<const HloInstruction*, int> reverse_position;
297     for (int i = 0; i < full_module_sequence.size(); ++i) {
298       HloInstruction* instruction = full_module_sequence[i];
299       schedule.GetOrCreateSequence(instruction->parent())
300           .push_back(instruction);
301       reverse_position[instruction] = full_module_sequence.size() - i;
302     }
303 
304     // Hack the size_fn so that it returns a decreasing value as we step through
305     // the sequence. This lets us ensure the Alloc calls are in the sequence
306     // order. The Free calls are sorted by BufferValue.id, which is at least
307     // deterministic.
308     auto size_fn = [&reverse_position](const BufferValue& buffer) {
309       return reverse_position[buffer.instruction()];
310     };
311     auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
312         absl::make_unique<HeapCallRecorder>(&actual_calls_));
313     result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
314                                  *points_to_analysis_, size_fn)
315                   .ConsumeValueOrDie();
316   }
317 
module()318   HloModule* module() { return module_.get(); }
319 
320   // Returns the buffer defined at the given instruction and index.
BufferAt(const HloInstruction * instruction,const ShapeIndex & index) const321   const BufferValue* BufferAt(const HloInstruction* instruction,
322                               const ShapeIndex& index) const {
323     return points_to_analysis_->GetBufferDefinedAt(instruction, index)
324         .ConsumeValueOrDie();
325   }
326 
OffsetAt(const HloInstruction * instruction,const ShapeIndex & index)327   int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
328     const BufferValue* buffer = BufferAt(instruction, index);
329     return result_.chunk_map.at(buffer).offset;
330   }
331 
332   // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
ExpectCallSequence(const CallSequence & expected) const333   void ExpectCallSequence(const CallSequence& expected) const {
334     EXPECT_EQ(expected, actual_calls_);
335   }
336 
337   // Ensures the buffers defined by the respective (instruction,index) pairs are
338   // shared, relying on the unique offsets assigned in HeapCallRecorder::Alloc.
ExpectSharedBuffers(const HloInstruction * instruction_a,const ShapeIndex & index_a,const HloInstruction * instruction_b,const ShapeIndex & index_b)339   void ExpectSharedBuffers(const HloInstruction* instruction_a,
340                            const ShapeIndex& index_a,
341                            const HloInstruction* instruction_b,
342                            const ShapeIndex& index_b) {
343     int64 offset_a = OffsetAt(instruction_a, index_a);
344     int64 offset_b = OffsetAt(instruction_b, index_b);
345     EXPECT_EQ(offset_a, offset_b);
346   }
347 
348  private:
349   std::unique_ptr<HloModule> module_;
350   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
351   CallSequence actual_calls_;
352   HeapSimulator::Result result_;
353 };
354 
355 class HeapSimulatorTest : public HloTestBase {
356  protected:
HeapSimulatorTest()357   HeapSimulatorTest() {}
~HeapSimulatorTest()358   ~HeapSimulatorTest() override {}
359 
360   // Shapes for use in the examples.
361   Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
362   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
363 };
364 
// A computation holding only a constant produces no Alloc/Free calls, only
// the final Finish.
TEST_F(HeapSimulatorTest, ScalarConstant) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Constants aren't assigned.  See b/32248867
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {scalar});
  tracker.ExpectCallSequence({{kFinish, nullptr}});
}
374 
// A single parameter which is also the output is allocated once and freed
// once.
TEST_F(HeapSimulatorTest, OneParam) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "param0"));

  HeapSimulatorTracker tracker(TestName(), builder.Build(), {param});
  const BufferValue* const param_buf = tracker.BufferAt(param, {});
  tracker.ExpectCallSequence({
      {kAlloc, param_buf},
      {kFree, param_buf},
      {kFinish, nullptr},
  });
}
388 
// Parameters and the output of a lone multiply all stay live to the end, so
// every Free comes after every Alloc.
TEST_F(HeapSimulatorTest, Multiply) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  HloInstruction* const param_x = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  HloInstruction* const product =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32vec4_, HloOpcode::kMultiply, param_a, param_x));

  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {param_a, param_x, product});
  const BufferValue* const a_buf = tracker.BufferAt(param_a, {});
  const BufferValue* const x_buf = tracker.BufferAt(param_x, {});
  const BufferValue* const product_buf = tracker.BufferAt(product, {});
  tracker.ExpectCallSequence({
      {kAlloc, a_buf},
      {kAlloc, x_buf},
      {kAlloc, product_buf},
      // All params and outputs are freed at the end.
      {kFree, a_buf},
      {kFree, x_buf},
      {kFree, product_buf},
      {kFinish, nullptr},
  });
}
412 
// The add is the output and reuses the multiply's buffer, so it never gets an
// Alloc/Free of its own.
TEST_F(HeapSimulatorTest, MultiplyAdd) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  HloInstruction* const param_x = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  HloInstruction* const param_y = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
  HloInstruction* const product =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32vec4_, HloOpcode::kMultiply, param_a, param_x));
  HloInstruction* const sum =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32vec4_, HloOpcode::kAdd, product, param_y));

  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {param_a, param_x, product, param_y, sum});
  const BufferValue* const a_buf = tracker.BufferAt(param_a, {});
  const BufferValue* const x_buf = tracker.BufferAt(param_x, {});
  const BufferValue* const product_buf = tracker.BufferAt(product, {});
  const BufferValue* const y_buf = tracker.BufferAt(param_y, {});
  tracker.ExpectCallSequence({
      {kAlloc, a_buf},
      {kAlloc, x_buf},
      {kAlloc, product_buf},
      {kAlloc, y_buf},
      // All params and outputs are freed at the end.
      {kFree, a_buf},
      {kFree, x_buf},
      {kFree, product_buf},
      {kFree, y_buf},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(sum, {}, product, {});
}
443 
// The fusion's only operand is `neg`. Its buffer may be shared with at most
// one of the fusion's two tuple outputs: exactly one of the output offsets
// must equal neg's offset.
TEST_F(HeapSimulatorTest, BufferReusedOnce) {
  HeapSimulatorTracker tracker(TestName());
  auto builder = HloComputation::Builder(TestName());

  // Fused computation: (exp(A), negate(A)). The builder is used directly
  // instead of through an alias named `builder`, which used to shadow the
  // entry builder above.
  HloComputation::Builder fusion_builder("fusion");
  {
    auto* fused_param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(
            /*parameter_number=*/0, f32vec4_, "A"));
    auto* fused_exp = fusion_builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, fused_param));
    auto* fused_neg = fusion_builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate,
                                    fused_param));

    fusion_builder.AddInstruction(
        HloInstruction::CreateTuple({fused_exp, fused_neg}));
  }
  auto fusion_computation =
      tracker.module()->AddEmbeddedComputation(fusion_builder.Build());
  auto a_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule({a_param, neg, fusion});

  // These are offsets assigned by HeapCallRecorder, not buffers.
  const int64 neg_offset = tracker.OffsetAt(neg, {});
  const int64 output_offset_0 = tracker.OffsetAt(fusion, {0});
  const int64 output_offset_1 = tracker.OffsetAt(fusion, {1});
  // Only one buffer should be shared.
  EXPECT_TRUE((neg_offset == output_offset_0) ^
              (neg_offset == output_offset_1));
}
480 
// A dot consuming a multiply gets its own output buffer, because it cannot
// reuse the multiply's buffer. All buffers are freed only at the end.
TEST_F(HeapSimulatorTest, MultiplyDot) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  // NOTE(review): mul is rank 1 and paramY is rank 0, so lhs contracting dim 1
  // and rhs contracting dim 0 do not exist on these operands. The dot looks
  // shape-invalid and appears to be tolerated only because this module is not
  // run through a verifier. Confirm before reusing this pattern.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // The buffer for dot is the output, and it cannot be shared with the buffer
  // for mul, since dot isn't elementwise.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(dot, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot, {})},
      {kFinish, nullptr},
  });
}
516 
// Extends MultiplyDot with an elementwise add of the dot's result. The add is
// expected to reuse the dot's buffer, so add itself is never allocated.
TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  // NOTE(review): same contracting-dimension/rank mismatch as in MultiplyDot;
  // only the buffer structure of the dot matters to this test.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));

  // The buffer for add is the output, and it's shared with the buffer for dot.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot, add});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(dot, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, dot, {});
}
554 
// Two chained dots: mul is last used by dot0, so its buffer is freed before
// dot1 is allocated. Everything else lives until the end.
TEST_F(HeapSimulatorTest, MultiplyDotDot) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  // NOTE(review): same contracting-dimension/rank mismatch as in MultiplyDot;
  // only the buffer structure of the dots matters to this test.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // The buffer for dot1 is the output.  No buffers can be shared.  The buffer
  // for mul is freed before the end, since it's no longer used after dot0
  // finishes.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot0, dot1});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFinish, nullptr},
  });
}
595 
// Like MultiplyDotDot, but the root is a tuple of both dots. The tuple gets a
// buffer of its own (index {}), and dot0 stays live to the end because the
// output tuple refers to it.
TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  // NOTE(review): same contracting-dimension/rank mismatch as in MultiplyDot;
  // only the buffer structure of the dots matters to this test.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));

  // The buffers for dot0, dot1 and tuple are the output.  No buffers can be
  // shared.  The buffer for mul is freed before the end, since it's no longer
  // used after dot0 finishes.
  HeapSimulatorTracker tracker(
      TestName(), builder.Build(),
      {paramA, paramX, mul, paramY, dot0, dot1, tuple});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFree, tracker.BufferAt(tuple, {})},
      {kFinish, nullptr},
  });
}
641 
TEST_F(HeapSimulatorTest, IndependentTupleElements) {
  // Builds: (product, sum) -> tuple; broadcast(gte0(tuple)); a subtraction;
  // gte1(tuple); and an output tuple of (broadcast, difference, gte1).
  HloComputation::Builder builder(TestName());
  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
  HloInstruction* product = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kMultiply, param_a, param_b));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kAdd, param_a, param_b));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({product, sum}));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, pair, 0));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec4_, first, {0}));
  HloInstruction* difference =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32scalar_, HloOpcode::kSubtract, param_a, param_b));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, pair, 1));
  HloInstruction* output = builder.AddInstruction(
      HloInstruction::CreateTuple({broadcast, difference, second}));

  // The instructions are handed to the tracker in creation order.
  HeapSimulatorTracker tracker(
      TestName(), builder.Build(),
      {param_a, param_b, product, sum, pair, first, broadcast, difference,
       second, output});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(param_a, {})},
      {kAlloc, tracker.BufferAt(param_b, {})},
      {kAlloc, tracker.BufferAt(product, {})},
      {kAlloc, tracker.BufferAt(sum, {})},
      {kAlloc, tracker.BufferAt(pair, {})},
      {kAlloc, tracker.BufferAt(broadcast, {})},
      // The product is released as soon as the broadcast has consumed it,
      // even though the other GetTupleElement is still live.
      {kFree, tracker.BufferAt(product, {})},
      {kAlloc, tracker.BufferAt(difference, {})},
      // Nothing reads the intermediate tuple any more, so it is released.
      {kFree, tracker.BufferAt(pair, {})},
      {kAlloc, tracker.BufferAt(output, {})},
      // Parameters and everything that ends up in the output are freed only
      // at the very end.
      {kFree, tracker.BufferAt(param_a, {})},
      {kFree, tracker.BufferAt(param_b, {})},
      {kFree, tracker.BufferAt(sum, {})},
      {kFree, tracker.BufferAt(broadcast, {})},
      {kFree, tracker.BufferAt(difference, {})},
      {kFree, tracker.BufferAt(output, {})},
      {kFinish, nullptr},
  });
}
691 
TEST_F(HeapSimulatorTest,WholeModule)692 TEST_F(HeapSimulatorTest, WholeModule) {
693   HeapSimulatorTracker tracker(TestName());
694 
695   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
696   const Shape tuple_shape =
697       ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
698 
699   auto cond_builder = HloComputation::Builder("WhileCond");
700   HloInstruction* cond_param = cond_builder.AddInstruction(
701       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
702   HloInstruction* cond_iter = cond_builder.AddInstruction(
703       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
704   HloInstruction* cond_data = cond_builder.AddInstruction(
705       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
706   HloInstruction* cond_lt = cond_builder.AddInstruction(
707       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
708                                     cond_data, ComparisonDirection::kLt));
709   HloComputation* cond_computation =
710       tracker.module()->AddEmbeddedComputation(cond_builder.Build());
711 
712   auto body_builder = HloComputation::Builder("WhileBody");
713   HloInstruction* body_param = body_builder.AddInstruction(
714       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
715   HloComputation* body_computation =
716       tracker.module()->AddEmbeddedComputation(body_builder.Build());
717 
718   auto builder = HloComputation::Builder(TestName());
719   HloInstruction* param = builder.AddInstruction(
720       HloInstruction::CreateParameter(0, tuple_shape, "param"));
721   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
722       tuple_shape, cond_computation, body_computation, param));
723   tracker.module()->AddEntryComputation(builder.Build());
724 
725   tracker.RunWholeModule(
726       {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
727   tracker.ExpectCallSequence({
728       // The entry computation param and while_op are allocated first.
729       {kAlloc, tracker.BufferAt(param, {})},
730       {kAlloc, tracker.BufferAt(param, {0})},
731       {kAlloc, tracker.BufferAt(param, {1})},
732       {kAlloc, tracker.BufferAt(while_op, {})},
733       {kAlloc, tracker.BufferAt(while_op, {0})},
734       {kAlloc, tracker.BufferAt(while_op, {1})},
735 
736       // Now the while body param is allocated and freed.
737       {kAlloc, tracker.BufferAt(body_param, {})},
738       {kAlloc, tracker.BufferAt(body_param, {0})},
739       {kAlloc, tracker.BufferAt(body_param, {1})},
740       {kFree, tracker.BufferAt(body_param, {})},
741       {kFree, tracker.BufferAt(body_param, {0})},
742       {kFree, tracker.BufferAt(body_param, {1})},
743 
744       // Now the while cond param is allocated. The GTE instructions just alias
745       // the param elements, so the param tuple can immediately be freed.
746       {kAlloc, tracker.BufferAt(cond_param, {})},
747       {kAlloc, tracker.BufferAt(cond_param, {0})},
748       {kAlloc, tracker.BufferAt(cond_param, {1})},
749       {kFree, tracker.BufferAt(cond_param, {})},
750 
751       // Now the final cond less-than buffer is allocated.
752       {kAlloc, tracker.BufferAt(cond_lt, {})},
753 
754       // The order of the remaining Free calls is based on the BufferValue.id,
755       // which is deterministic, but not obvious.
756       {kFree, tracker.BufferAt(param, {})},
757       {kFree, tracker.BufferAt(param, {0})},
758       {kFree, tracker.BufferAt(param, {1})},
759 
760       {kFree, tracker.BufferAt(while_op, {})},
761       {kFree, tracker.BufferAt(while_op, {0})},
762       {kFree, tracker.BufferAt(while_op, {1})},
763 
764       {kFree, tracker.BufferAt(cond_param, {0})},
765       {kFree, tracker.BufferAt(cond_param, {1})},
766       {kFree, tracker.BufferAt(cond_lt, {})},
767 
768       {kFinish, nullptr},
769   });
770 }
771 
772 // Base class for heap algorithm tests.
773 class HeapAlgorithmTestBase : public ::testing::Test {
774  protected:
HeapAlgorithmTestBase()775   HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
776     buffer_a_ = DummyBufferValue();
777     buffer_b_ = DummyBufferValue();
778     buffer_c_ = DummyBufferValue();
779     buffer_d_ = DummyBufferValue();
780     buffer_e_ = DummyBufferValue();
781     buffer_f_ = DummyBufferValue();
782     buffer_g_ = DummyBufferValue();
783     buffer_h_ = DummyBufferValue();
784     buffer_i_ = DummyBufferValue();
785   }
~HeapAlgorithmTestBase()786   ~HeapAlgorithmTestBase() override {}
787 
788   const BufferValue* buffer_a_;
789   const BufferValue* buffer_b_;
790   const BufferValue* buffer_c_;
791   const BufferValue* buffer_d_;
792   const BufferValue* buffer_e_;
793   const BufferValue* buffer_f_;
794   const BufferValue* buffer_g_;
795   const BufferValue* buffer_h_;
796   const BufferValue* buffer_i_;
797 
798  private:
799   // Create a dummy BufferValue to pass to the heap algorithm.
DummyBufferValue()800   const BufferValue* DummyBufferValue() {
801     const BufferValue::Id id = buffers_.size();
802     auto const0 = builder_.AddInstruction(
803         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
804     buffers_.emplace_back(
805         absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
806     return buffers_.back().get();
807   }
808 
809   HloComputation::Builder builder_;
810   std::vector<std::unique_ptr<BufferValue>> buffers_;
811 };
812 
813 class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
814 
TEST_F(NoFragmentationStatsHeapTest,Empty)815 TEST_F(NoFragmentationStatsHeapTest, Empty) {
816   NoFragmentationStatsHeap heap;
817   EXPECT_EQ(0, heap.Finish().heap_size);
818 }
819 
TEST_F(NoFragmentationStatsHeapTest, Simple) {
  NoFragmentationStatsHeap heap;
  const struct {
    const BufferValue* buffer;
    int size;
  } buffers[] = {
      {buffer_a_, 10}, {buffer_b_, 20}, {buffer_c_, 30}, {buffer_d_, 30}};

  // Every buffer is allocated before any is freed, so all four are live at
  // once and the peak is their total: 10 + 20 + 30 + 30 = 90.
  for (const auto& entry : buffers) heap.Alloc(entry.buffer, entry.size);
  for (const auto& entry : buffers) heap.Free(entry.buffer, entry.size);

  EXPECT_EQ(90, heap.Finish().heap_size);
}
832 
TEST_F(NoFragmentationStatsHeapTest, Mixed) {
  NoFragmentationStatsHeap heap;
  // Allocates a buffer and frees it straight away, so it is only briefly live.
  auto alloc_then_free = [&heap](const BufferValue* buffer, int size) {
    heap.Alloc(buffer, size);
    heap.Free(buffer, size);
  };

  heap.Alloc(buffer_a_, 10);       // live: A          (10)
  alloc_then_free(buffer_b_, 20);  // peak: A + B      (30)
  alloc_then_free(buffer_c_, 30);  // peak: A + C      (40)
  alloc_then_free(buffer_d_, 5);   // A + D (15) stays below the peak
  heap.Free(buffer_a_, 10);

  EXPECT_EQ(40, heap.Finish().heap_size);
}
849 
850 class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};
851 
TEST_F(DecreasingSizeRunsHeapTest,Empty)852 TEST_F(DecreasingSizeRunsHeapTest, Empty) {
853   CallSequence call_sequence;
854   DecreasingSizeRunsHeap heap(
855       absl::make_unique<HeapCallRecorder>(&call_sequence));
856   heap.Finish();
857   EXPECT_EQ(call_sequence, CallSequence({
858                                {kFinish, nullptr},
859                            }));
860 }
861 
TEST_F(DecreasingSizeRunsHeapTest, Simple) {
  CallSequence recorded;
  DecreasingSizeRunsHeap heap(absl::make_unique<HeapCallRecorder>(&recorded));

  const struct {
    const BufferValue* buffer;
    int size;
  } buffers[] = {
      {buffer_a_, 10}, {buffer_b_, 20}, {buffer_c_, 30}, {buffer_d_, 30}};

  // One uninterrupted run of four Allocs, then one run of four Frees.
  for (const auto& entry : buffers) heap.Alloc(entry.buffer, entry.size);
  for (const auto& entry : buffers) heap.Free(entry.buffer, entry.size);
  heap.Finish();

  // Within each run the calls come out largest first. The two 30-byte
  // buffers are ordered by buffer id (c was created before d).
  const CallSequence expected = {
      {kAlloc, buffer_c_}, {kAlloc, buffer_d_}, {kAlloc, buffer_b_},
      {kAlloc, buffer_a_}, {kFree, buffer_c_},  {kFree, buffer_d_},
      {kFree, buffer_b_},  {kFree, buffer_a_},  {kFinish, nullptr},
  };
  EXPECT_EQ(recorded, expected);
}
889 
TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
  CallSequence recorded;
  DecreasingSizeRunsHeap heap(absl::make_unique<HeapCallRecorder>(&recorded));
  // Allocates a buffer and frees it straight away.
  auto alloc_then_free = [&heap](const BufferValue* buffer, int size) {
    heap.Alloc(buffer, size);
    heap.Free(buffer, size);
  };

  heap.Alloc(buffer_a_, 10);
  alloc_then_free(buffer_b_, 20);
  alloc_then_free(buffer_c_, 30);
  alloc_then_free(buffer_d_, 5);
  heap.Free(buffer_a_, 10);
  heap.Finish();

  // Each maximal run of Allocs or of Frees is reordered by decreasing size;
  // the runs themselves keep their original order.
  const CallSequence expected = {
      {kAlloc, buffer_b_},  // run: Alloc a, Alloc b
      {kAlloc, buffer_a_},
      {kFree, buffer_b_},   // run: Free b
      {kAlloc, buffer_c_},  // run: Alloc c
      {kFree, buffer_c_},   // run: Free c
      {kAlloc, buffer_d_},  // run: Alloc d
      {kFree, buffer_a_},   // run: Free d, Free a
      {kFree, buffer_d_},
      {kFinish, nullptr},
  };
  EXPECT_EQ(recorded, expected);
}
920 
921 class LazyBestFitHeapTest : public HeapAlgorithmTestBase {};
922 
TEST_F(LazyBestFitHeapTest,Empty)923 TEST_F(LazyBestFitHeapTest, Empty) {
924   LazyBestFitHeap heap(/*alignment=*/1);
925   const HeapSimulator::Result result = heap.Finish();
926   EXPECT_EQ(0, result.heap_size);
927   EXPECT_EQ(0, result.chunk_map.size());
928 }
929 
TEST_F(LazyBestFitHeapTest, Simple) {
  LazyBestFitHeap heap(/*alignment=*/1);
  const struct {
    const BufferValue* buffer;
    int size;
    int offset;
  } chunks[] = {
      {buffer_a_, 10, 0},
      {buffer_b_, 20, 10},
      {buffer_c_, 30, 30},
      {buffer_d_, 30, 60},
  };

  // Everything is allocated before anything is freed, so no free chunk ever
  // exists to reuse. Offsets are handed out one after another as the buffers
  // are freed.
  for (const auto& entry : chunks) heap.Alloc(entry.buffer, entry.size);
  for (const auto& entry : chunks) heap.Free(entry.buffer, entry.size);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(90, result.heap_size);
  for (const auto& expected : chunks) {
    const auto& chunk = result.chunk_map.at(expected.buffer);
    EXPECT_EQ(expected.size, chunk.size);
    EXPECT_EQ(expected.offset, chunk.offset);
  }
}
953 
TEST_F(LazyBestFitHeapTest, Mixed) {
  // Offsets are chosen lazily. The comments track each buffer's final range
  // and the free list after every call.
  LazyBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);  // A lazy offset

  heap.Alloc(buffer_b_, 20);  // B lazy offset
  heap.Free(buffer_b_, 20);   // B range = [0, 20)   free = [0, 20)

  heap.Alloc(buffer_c_, 30);  // C range = [0, 30)
  heap.Free(buffer_c_, 30);   //                     free = [0, 30)

  heap.Alloc(buffer_d_, 5);  // D range = [0, 5)    free = [5, 30)
  heap.Free(buffer_d_, 5);   //                     free = [0, 30)

  heap.Free(buffer_a_, 10);  // A range = [30, 40)  free = [0, 40)

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(40, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(5, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}
981 
TEST_F(LazyBestFitHeapTest, BestFit) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // Alloc and free A up front so a single 200-byte free chunk exists.
  heap.Alloc(buffer_a_, 200);  // A lazy offset
  heap.Free(buffer_a_, 200);   // A = [0, 200)     free = [0, 200)

  // Carve B..H out of that chunk, back to back.
  heap.Alloc(buffer_b_, 30);  // B = [0, 30)       free = [30, 200)
  heap.Alloc(buffer_c_, 30);  // C = [30, 60)      free = [60, 200)
  heap.Alloc(buffer_d_, 20);  // D = [60, 80)      free = [80, 200)
  heap.Alloc(buffer_e_, 20);  // E = [80, 100)     free = [100, 200)
  heap.Alloc(buffer_f_, 10);  // F = [100, 110)    free = [110, 200)
  heap.Alloc(buffer_g_, 10);  // G = [110, 120)    free = [120, 200)
  heap.Alloc(buffer_h_, 80);  // H = [120, 200)    no free space left

  // Punch holes of 30, 20 and 10 bytes.
  heap.Free(buffer_c_, 30);  // free = [30, 60)
  heap.Free(buffer_e_, 20);  // free = [30, 60), [80, 100)
  heap.Free(buffer_g_, 10);  // free = [30, 60), [80, 100), [110, 120)

  // I (15 bytes) goes into the smallest hole that fits: the 20-byte one.
  heap.Alloc(buffer_i_, 15);  // I = [80, 95)

  // Release everything; these frees exercise free-chunk coalescing.
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_d_, 20);
  heap.Free(buffer_f_, 10);
  heap.Free(buffer_h_, 80);
  heap.Free(buffer_i_, 15);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(200, result.heap_size);
  const struct {
    const BufferValue* buffer;
    int size;
    int offset;
  } expected_chunks[] = {
      {buffer_a_, 200, 0},  {buffer_b_, 30, 0},   {buffer_c_, 30, 30},
      {buffer_d_, 20, 60},  {buffer_e_, 20, 80},  {buffer_f_, 10, 100},
      {buffer_g_, 10, 110}, {buffer_h_, 80, 120}, {buffer_i_, 15, 80},
  };
  for (const auto& expected : expected_chunks) {
    const auto& chunk = result.chunk_map.at(expected.buffer);
    EXPECT_EQ(expected.size, chunk.size);
    EXPECT_EQ(expected.offset, chunk.offset);
  }
}
1035 
TEST_F(LazyBestFitHeapTest, Lazy) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // Allocate three buffers without committing to offsets yet.
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 5);
  heap.Alloc(buffer_c_, 10);

  // Freeing is what pins down an offset.
  heap.Free(buffer_a_, 10);  // A = [0, 10)    free = [0, 10)
  heap.Free(buffer_c_, 10);  // C = [10, 20)   free = [0, 20)

  // Because offsets were assigned lazily, A and C ended up adjacent and form
  // one 20-byte free chunk that D fits in exactly.
  heap.Alloc(buffer_d_, 20);  // D = [0, 20)

  heap.Free(buffer_b_, 5);  // B = [20, 25)
  heap.Free(buffer_d_, 20);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(25, result.heap_size);
  const struct {
    const BufferValue* buffer;
    int size;
    int offset;
  } expected_chunks[] = {
      {buffer_a_, 10, 0},
      {buffer_b_, 5, 20},
      {buffer_c_, 10, 10},
      {buffer_d_, 20, 0},
  };
  for (const auto& expected : expected_chunks) {
    const auto& chunk = result.chunk_map.at(expected.buffer);
    EXPECT_EQ(expected.size, chunk.size);
    EXPECT_EQ(expected.offset, chunk.offset);
  }
}
1067 
TEST_F(LazyBestFitHeapTest, ReuseLastFreeChunk) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // Alloc and free A up front so a single 60-byte free chunk exists.
  heap.Alloc(buffer_a_, 60);  // A lazy offset
  heap.Free(buffer_a_, 60);   // A = [0, 60)     free = [0, 60)

  // Carve B, C and D out of that chunk.
  heap.Alloc(buffer_b_, 10);  // B = [0, 10)     free = [10, 60)
  heap.Alloc(buffer_c_, 20);  // C = [10, 30)    free = [30, 60)
  heap.Alloc(buffer_d_, 30);  // D = [30, 60)    no free space left

  // Leave a 10-byte hole and a 30-byte hole at the end of the heap.
  heap.Free(buffer_b_, 10);  // free = [0, 10)
  heap.Free(buffer_d_, 30);  // free = [0, 10), [30, 60)

  // Neither hole holds 40 bytes, but the trailing one borders the end of the
  // heap, so E reuses it and the heap grows past it.
  heap.Alloc(buffer_e_, 40);  // E = [30, 70)

  heap.Free(buffer_c_, 20);
  heap.Free(buffer_e_, 40);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(70, result.heap_size);
  const struct {
    const BufferValue* buffer;
    int size;
    int offset;
  } expected_chunks[] = {
      {buffer_a_, 60, 0},  {buffer_b_, 10, 0},  {buffer_c_, 20, 10},
      {buffer_d_, 30, 30}, {buffer_e_, 40, 30},
  };
  for (const auto& expected : expected_chunks) {
    const auto& chunk = result.chunk_map.at(expected.buffer);
    EXPECT_EQ(expected.size, chunk.size);
    EXPECT_EQ(expected.offset, chunk.offset);
  }
}
1105 
TEST_F(LazyBestFitHeapTest, Alignment) {
  LazyBestFitHeap heap(/*alignment=*/64);

  // Three allocations whose offsets are not fixed yet.
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 5);
  heap.Alloc(buffer_c_, 10);

  // Freeing pins an offset, rounded up to a multiple of 64.
  heap.Free(buffer_a_, 10);  // A = [0, 10)      free = [0, 10)
  heap.Free(buffer_c_, 10);  // C = [64, 74)     free = [0, 74)

  // Lazy offsets plus accounting for the alignment padding leave one 74-byte
  // free chunk, large enough for D.
  heap.Alloc(buffer_d_, 74);  // D = [0, 74)     no free space left

  heap.Free(buffer_b_, 5);    // B = [128, 133)  free = [74, 133)
  heap.Alloc(buffer_e_, 23);  // E = [128, 151)  free = [74, 128)

  heap.Free(buffer_d_, 74);  //                  free = [0, 128)
  heap.Free(buffer_e_, 23);  //                  free = [0, 151)

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(151, result.heap_size);
  const struct {
    const BufferValue* buffer;
    int size;
    int offset;
  } expected_chunks[] = {
      {buffer_a_, 10, 0},  {buffer_b_, 5, 128},  {buffer_c_, 10, 64},
      {buffer_d_, 74, 0},  {buffer_e_, 23, 128},
  };
  for (const auto& expected : expected_chunks) {
    const auto& chunk = result.chunk_map.at(expected.buffer);
    EXPECT_EQ(expected.size, chunk.size);
    EXPECT_EQ(expected.offset, chunk.offset);
  }
}
1142 
1143 class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {};
1144 
TEST_F(GlobalDecreasingSizeBestFitHeapTest,Empty)1145 TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
1146   GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
1147   const HeapSimulator::Result result = heap.Finish();
1148   EXPECT_EQ(0, result.heap_size);
1149   EXPECT_EQ(0, result.chunk_map.size());
1150 }
1151 
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // space
  //   ^
  //   |  +---a---+
  //   |      +-------+
  //   |      +---c---+
  //   |    +-------+
  //   |    |   b   |
  //   |    +-------+
  //   |         +-------+
  //   |         |       |
  //   |         |   d   |
  //   |         +-------+
  //   -----------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
  const struct {
    const BufferValue* buffer;
    int size;
    int offset;
  } chunks[] = {
      {buffer_a_, 10, 90},
      {buffer_b_, 30, 40},
      {buffer_c_, 20, 70},
      {buffer_d_, 40, 0},
  };

  // All four buffers are live at the same time, so none can share space.
  // Placing them largest first stacks d, b, c, a from offset 0 upward.
  for (const auto& entry : chunks) heap.Alloc(entry.buffer, entry.size);
  for (const auto& entry : chunks) heap.Free(entry.buffer, entry.size);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(100, result.heap_size);
  for (const auto& expected : chunks) {
    const auto& chunk = result.chunk_map.at(expected.buffer);
    EXPECT_EQ(expected.size, chunk.size);
    EXPECT_EQ(expected.offset, chunk.offset);
  }
}
1188 
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
  // space
  //   ^
  //   |      +-------+
  //   |      +---b---+
  //   |            +-------+
  //   |            |       |
  //   |            |   d   |
  //   |  +---a---+ +-------+
  //   |
  //   |         +-------+
  //   |         |       |
  //   |         |   c   |
  //   |         |       |
  //   |         +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  // a is already dead when d is allocated, so the two can share offset 60.
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(120, result.heap_size);
  // Every offset is a multiple of the 20-byte alignment.
  const struct {
    const BufferValue* buffer;
    int size;
    int offset;
  } expected_chunks[] = {
      {buffer_a_, 10, 60},
      {buffer_b_, 20, 100},
      {buffer_c_, 50, 0},
      {buffer_d_, 40, 60},
  };
  for (const auto& expected : expected_chunks) {
    const auto& chunk = result.chunk_map.at(expected.buffer);
    EXPECT_EQ(expected.size, chunk.size);
    EXPECT_EQ(expected.offset, chunk.offset);
  }
}
1227 
TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
  // space
  //   ^
  //   |    +-------+
  //   |    +---b---+
  //   |         +-------+
  //   |         |   d   |
  //   | +--a--+ +-------+
  //   |      +-------+
  //   |      |       |
  //   |      |   c   |
  //   |      +-------+
  //   |           +-------+
  //   |           |       |
  //   |           |   e   |
  //   |           |       |
  //   |           +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 40);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 30);
  heap.Alloc(buffer_e_, 50);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_d_, 30);
  heap.Free(buffer_e_, 50);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(140, result.heap_size);
  // a is never live at the same time as d or e. It could go either in e's
  // 50-byte range [0, 50) or in d's 30-byte range [90, 120); best fit picks
  // the smaller one, so a lands at 90.
  const struct {
    const BufferValue* buffer;
    int size;
    int offset;
  } expected_chunks[] = {
      {buffer_a_, 10, 90}, {buffer_b_, 20, 120}, {buffer_c_, 40, 50},
      {buffer_d_, 30, 90}, {buffer_e_, 50, 0},
  };
  for (const auto& expected : expected_chunks) {
    const auto& chunk = result.chunk_map.at(expected.buffer);
    EXPECT_EQ(expected.size, chunk.size);
    EXPECT_EQ(expected.offset, chunk.offset);
  }
}
1272 
1273 }  // namespace
1274 }  // namespace xla
1275