1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
17 
18 #include <algorithm>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/map_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
26 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
27 #include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
28 #include "tensorflow/compiler/xla/util.h"
29 
30 namespace xla {
31 
32 using absl::flat_hash_map;
33 using absl::flat_hash_set;
34 
OverlapsWith(Chunk other_chunk) const35 bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
36   CHECK_NE(size, 0);
37   CHECK_NE(other_chunk.size, 0);
38   return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
39 }
40 
41 /*static*/
MinimumMemoryForModule(const HloSchedule & schedule,const LogicalBuffer::SizeFunction & size_function)42 StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
43     const HloSchedule& schedule,
44     const LogicalBuffer::SizeFunction& size_function) {
45   if (schedule.empty()) {
46     return 0;
47   }
48   const HloModule* module = schedule.module();
49 
50   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
51                       HloAliasAnalysis::Run(module));
52 
53   // The absolute minimum memory required for a given sequence of instructions
54   // is determined by the sequence of Alloc and Free calls on a simulated heap,
55   // ignoring fragmentation. We run the heap simulation on the whole module,
56   // rather than summing each computation, since it gives us a better lower
57   // bound, by minimizing the liveness of sub-computations.
58   TF_ASSIGN_OR_RETURN(
59       HeapSimulator::Result<HloValue> result,
60       HeapSimulator::Run(
61           absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), *module,
62           schedule, *alias_analysis, size_function));
63   return result.heap_size;
64 }
65 
66 /*static*/
MinimumMemoryForComputation(const HloComputation & computation,const HloInstructionSequence & sequence,const HloAliasAnalysis & alias_analysis,const LogicalBuffer::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64> * memory_by_computation)67 StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
68     const HloComputation& computation, const HloInstructionSequence& sequence,
69     const HloAliasAnalysis& alias_analysis,
70     const LogicalBuffer::SizeFunction& size_function,
71     const absl::flat_hash_map<const HloComputation*, int64>*
72         memory_by_computation) {
73   TF_ASSIGN_OR_RETURN(
74       HeapSimulator::Result<HloValue> result,
75       HeapSimulator::Run(
76           absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
77           sequence, alias_analysis, size_function, HeapSimulator::Options(),
78           memory_by_computation));
79   return result.heap_size;
80 }
81 
MinimumMemoryForComputation(const HloComputation & computation,const HloInstructionSequence & sequence,const HloAliasAnalysis & alias_analysis,const LogicalBuffer::SizeFunction & size_function,const HloSchedule * schedule)82 StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
83     const HloComputation& computation, const HloInstructionSequence& sequence,
84     const HloAliasAnalysis& alias_analysis,
85     const LogicalBuffer::SizeFunction& size_function,
86     const HloSchedule* schedule) {
87   TF_ASSIGN_OR_RETURN(
88       HeapSimulator::Result<HloValue> result,
89       HeapSimulator::Run(
90           absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
91           sequence, alias_analysis, size_function, schedule,
92           HeapSimulator::Options()));
93   return result.heap_size;
94 }
95 
96 /*static*/
Run(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,const HloModule & module,const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_fn,const Options & options)97 StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
98     std::unique_ptr<HeapAlgorithm<HloValue>> algorithm, const HloModule& module,
99     const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
100     const BufferValue::SizeFunction& size_fn, const Options& options) {
101   HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
102   const HloComputation* entry_computation = module.entry_computation();
103   const HloInstructionSequence& instruction_sequence =
104       schedule.sequence(entry_computation);
105   TF_ASSIGN_OR_RETURN(
106       std::unique_ptr<HloLiveRange> hlo_live_range,
107       HloLiveRange::Run(schedule, alias_analysis, entry_computation));
108   TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
109                                          instruction_sequence, alias_analysis,
110                                          hlo_live_range.get()));
111   return heap.Finish();
112 }
113 
114 /*static*/
Run(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,const HloComputation & computation,const HloInstructionSequence & instruction_sequence,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_fn,const Options & options,const absl::flat_hash_map<const HloComputation *,int64> * memory_by_computation)115 StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
116     std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
117     const HloComputation& computation,
118     const HloInstructionSequence& instruction_sequence,
119     const HloAliasAnalysis& alias_analysis,
120     const BufferValue::SizeFunction& size_fn, const Options& options,
121     const absl::flat_hash_map<const HloComputation*, int64>*
122         memory_by_computation) {
123   HeapSimulator heap(std::move(algorithm), size_fn, options,
124                      /*schedule=*/nullptr, memory_by_computation);
125   HloSchedule schedule(computation.parent());
126   schedule.set_sequence(&computation, instruction_sequence);
127   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
128                       HloLiveRange::Run(schedule, alias_analysis, &computation,
129                                         /*module_scoped_analysis=*/false));
130   TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
131                                          alias_analysis, hlo_live_range.get()));
132   return heap.Finish();
133 }
134 
135 /*static*/
Run(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,const HloComputation & computation,const HloInstructionSequence & instruction_sequence,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_fn,const HloSchedule * schedule,const Options & options)136 StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
137     std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
138     const HloComputation& computation,
139     const HloInstructionSequence& instruction_sequence,
140     const HloAliasAnalysis& alias_analysis,
141     const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
142     const Options& options) {
143   HeapSimulator heap(std::move(algorithm), size_fn, options,
144                      /*schedule=*/schedule, nullptr);
145   TF_ASSIGN_OR_RETURN(
146       std::unique_ptr<HloLiveRange> hlo_live_range,
147       HloLiveRange::Run(*schedule, alias_analysis, &computation));
148   TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
149                                          alias_analysis, hlo_live_range.get()));
150   return heap.Finish();
151 }
152 
153 // Runs a heap simulation for the given 'computation', assuming the given
154 // 'instruction_sequence'.
RunComputation(const HloComputation & computation,const HloInstructionSequence & instruction_sequence,const HloAliasAnalysis & alias_analysis,HloLiveRange * hlo_live_range)155 Status HeapSimulator::RunComputation(
156     const HloComputation& computation,
157     const HloInstructionSequence& instruction_sequence,
158     const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
159   XLA_VLOG_LINES(1, computation.parent()->ToString());
160   XLA_VLOG_LINES(2, computation.ToString());
161 
162   VLOG(1) << hlo_live_range->ToString();
163 
164   HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();
165 
166   // Record the buffer define/free event for each time step. We free all
167   // remaining buffers (entry parameter, etc) after the program has finished
168   // running, so we set the size of to program_end_time + 1.
169   std::vector<std::vector<const HloValue*>> buffers_defined(
170       hlo_live_range->schedule_end_time() + 1);
171   std::vector<std::vector<const HloValue*>> buffers_freed(
172       hlo_live_range->schedule_end_time() + 1);
173 
174   // values_to_assign tracks the HloValues that we need to assign a buffer to.
175   // Note that we only need to assign a buffer to a value when both of the
176   // following conditions are met:
177   //
178   // - The user specifically asks us to assign a buffer to a set of HloValues,
179   // and the value is in the set. If the user don't provide such a set, by
180   // default we assign buffer to all HloValues.
181   //
182   // - If the instruction is in a nested call of the current computation, only
183   // assign a buffer if we are doing global heap simulation.
184   std::vector<const HloValue*> values_to_assign;
185   values_to_assign.reserve(dataflow_analysis.values().size());
186 
187   for (const HloValue* value : dataflow_analysis.values()) {
188     // Ignore buffers that are not tracked.
189     if (hlo_live_range->instruction_schedule().count(
190             value->defining_instruction()) == 0) {
191       continue;
192     }
193     if (IgnoreBuffer(value)) {
194       continue;
195     }
196     values_to_assign.push_back(value);
197   }
198 
199   auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();
200 
201   absl::c_sort(values_to_assign,
202                [&](const HloValue* value1, const HloValue* value2) {
203                  const auto& live_range1 = buffer_live_ranges.at(value1);
204                  const auto& live_range2 = buffer_live_ranges.at(value2);
205                  return std::forward_as_tuple(live_range1.start,
206                                               live_range1.end, value1->id()) <
207                         std::forward_as_tuple(live_range2.start,
208                                               live_range2.end, value2->id());
209                });
210 
211   // For each value that we need to assign a buffer to, add the define and free
212   // events.
213   for (const HloValue* value : values_to_assign) {
214     auto live_range = buffer_live_ranges.at(value);
215     buffers_defined[live_range.start].push_back(value);
216     buffers_freed[live_range.end].push_back(value);
217   }
218 
219   // All HloValues in a hlo buffer should be allocated to the same address. This
220   // map tracks the first value that got allocated in a buffer.
221   absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;
222 
223   VLOG(1) << "Program time" << hlo_live_range->schedule_end_time();
224 
225   // Go through each step in the program and replay each buffer define and free
226   // events.
227   for (int64 i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
228     VLOG(1) << "Time step: " << i;
229 
230     for (const HloValue* value : buffers_defined[i]) {
231       bool shared = false;
232       VLOG(1) << "Start buffer: " << value->ToShortString();
233       const HloBuffer* hlo_buffer =
234           &alias_analysis.GetBufferContainingValue(*value);
235       if (first_allocated_value.count(hlo_buffer) != 0) {
236         // We've already assigned an address for another value in this HloBuffer
237         // (HloBuffer holds several aliased HloValues). All values in a buffer
238         // should be assigned the same address. Find the one that's already
239         // allocated and reuse its address.
240         ShareBuffer(value, first_allocated_value[hlo_buffer],
241                     value->instruction());
242         VLOG(1) << "  ShareWith"
243                 << first_allocated_value[hlo_buffer]->ToShortString();
244         continue;
245       }
246       if (options_.may_reuse_operand_buffers &&
247           hlo_buffer->values().size() == 1) {
248         // We don't support sharing an aliased buffer
249         // (hlo_buffer->values().size() > 1) with its operand.
250         for (const HloInstruction* operand : value->instruction()->operands()) {
251           const HloValueSet operand_value_set =
252               dataflow_analysis.GetValueSet(operand);
253           for (const HloValue* operand_value : operand_value_set.values()) {
254             const HloBuffer* operand_buffer =
255                 &alias_analysis.GetBufferContainingValue(*operand_value);
256             if (operand_buffer->values().size() > 1) {
257               continue;
258             }
259             auto it = buffer_live_ranges.find(operand_value);
260             if (it == buffer_live_ranges.end()) {
261               continue;
262             }
263 
264             auto& operand_live_range = it->second;
265 
266             auto& user_live_range = buffer_live_ranges[value];
267 
268             // Can only share buffers that are about to be freed.
269             if (operand_live_range.end != i) {
270               continue;
271             }
272 
273             if (IgnoreBuffer(operand_value)) {
274               continue;
275             }
276 
277             if (!absl::c_linear_search(buffers_freed[i], operand_value)) {
278               // If the operand buffer is not being freed (either because it has
279               // existing users, or it has been reused by other buffers), don't
280               // consider the operand as a candidate of buffer sharing.
281               continue;
282             }
283 
284             // The instruction that defines the operand value can be different
285             // from the actual operand, if directly passing the defining
286             // instruction into "CanShareOperandBufferWithUser" it creates a
287             // check failure. The first condition guards against that case.
288             if (value->instruction()->IsUserOf(operand_value->instruction()) &&
289                 value->instruction()->opcode() != HloOpcode::kCopy &&
290                 dataflow_analysis.CanShareOperandBufferWithUser(
291                     operand_value->instruction(), operand_value->index(),
292                     value->instruction(), value->index())) {
293               // Remove the operand buffer right before sharing (allocating) a
294               // new one.
295               Free(operand_value, operand_value->instruction());
296               buffers_freed[i].erase(
297                   std::remove(buffers_freed[i].begin(), buffers_freed[i].end(),
298                               operand_value),
299                   buffers_freed[i].end());
300               ShareBuffer(value, operand_value, value->instruction());
301               // The live range of the operand buffer is now extended to the end
302               // of the current instruction.
303               operand_live_range.end = user_live_range.end;
304               VLOG(1) << "Sharing " << value->ToShortString() << " with "
305                       << operand_value->ToShortString()
306                       << ", size:" << size_fn_(*value);
307               shared = true;
308               break;
309             }
310           }
311           if (shared) {
312             break;
313           }
314         }
315       }
316       if (!shared) {
317         Alloc(value, value->instruction());
318         first_allocated_value[hlo_buffer] = value;
319       }
320     }
321 
322     if (!buffers_freed[i].empty()) {
323       VLOG(1) << "Free Buffer: ";
324     }
325     for (const HloValue* value : buffers_freed[i]) {
326       VLOG(1) << "  " << value->ToShortString();
327 
328       Free(value, value->instruction());
329     }
330   }
331   return Status::OK();
332 }
333 
HeapSimulator(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,const BufferValue::SizeFunction & size_fn,const Options & options,const HloSchedule * schedule,const absl::flat_hash_map<const HloComputation *,int64> * memory_by_computation)334 HeapSimulator::HeapSimulator(
335     std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
336     const BufferValue::SizeFunction& size_fn, const Options& options,
337     const HloSchedule* schedule,
338     const absl::flat_hash_map<const HloComputation*, int64>*
339         memory_by_computation)
340     : no_fragmentation_stats_(
341           absl::make_unique<NoFragmentationStatsHeap<HloValue>>()),
342       algorithm_(std::move(algorithm)),
343       size_fn_(size_fn),
344       options_(options),
345       schedule_(schedule),
346       memory_by_computation_(memory_by_computation) {
347   debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
348 }
349 
~HeapSimulator()350 HeapSimulator::~HeapSimulator() {}
351 
IgnoreBuffer(const HloValue * buffer) const352 bool HeapSimulator::IgnoreBuffer(const HloValue* buffer) const {
353   // Buffers for constants are ignored unless the alloc_constants option is
354   // set. Also ignore buffers that we're not meant to assign.
355   //
356   // TODO(b/32248867): For consistency, constants should get allocations.
357   if (!options_.alloc_constants &&
358       buffer->instruction()->opcode() == HloOpcode::kConstant) {
359     return true;
360   }
361   return options_.buffers_to_assign != nullptr &&
362          !options_.buffers_to_assign->contains(buffer);
363 }
364 
365 // Alloc always calls the underlying heap algorithm.
Alloc(const HloValue * buffer,const HloInstruction * instruction)366 void HeapSimulator::Alloc(const HloValue* buffer,
367                           const HloInstruction* instruction) {
368   CHECK(!allocated_buffers_.contains(buffer))
369       << "Alloc called on allocated buffer: " << *buffer;
370   CHECK(!freed_buffers_.contains(buffer))
371       << "Alloc called on freed buffer: " << *buffer;
372 
373   allocated_buffers_.insert(buffer);
374   const int64 size = size_fn_(*buffer);
375   algorithm_->Alloc(buffer, size);
376   no_fragmentation_stats_->Alloc(buffer, size);
377   FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
378                  nullptr);
379 }
380 
381 // Free calls the underlying algorithm for non-shared buffers, and for shared
382 // buffers whose group liveness has expired.  Shared group liveness is tracked
383 // by maintaining a refcount; the Free call on the last buffer in the group
384 // causes Free to be called on the underlying algorithm.
Free(const HloValue * buffer,const HloInstruction * instruction)385 void HeapSimulator::Free(const HloValue* buffer,
386                          const HloInstruction* instruction) {
387   const int64 size = size_fn_(*buffer);
388   algorithm_->Free(buffer, size);
389   no_fragmentation_stats_->Free(buffer, size);
390   FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
391 }
392 
393 // ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
394 // The 'buffer' must be a non-allocated, non-freed buffer, just like in calls
395 // to Alloc.  The 'shared' buffer must be a previously allocated or shared
396 // buffer. Both 'buffer' and 'shared' will be associated with the same
397 // SharedGroup.
ShareBuffer(const HloValue * buffer,const HloValue * shared,const HloInstruction * instruction)398 void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
399                                 const HloInstruction* instruction) {
400   algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
401   no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
402   FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
403                  shared);
404 }
405 
Finish()406 HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
407   Result<HloValue> result = algorithm_->Finish();
408 
409   // Post-process the result to add chunks for shared buffers.  An empty chunk
410   // map means that either no buffers were allocated, or the heap was only
411   // collecting statistics, e.g. NoFragmentationStatsHeap.
412   size_t total_chunk_count = absl::c_accumulate(
413       result.heap_results, static_cast<size_t>(0),
414       [&](size_t lhs, const HeapResult<HloValue>& rhs) -> size_t {
415         return lhs + rhs.chunk_map.size();
416       });
417   if (total_chunk_count != 0) {
418     // If we were told to assign specific buffers, make sure we've assigned
419     // exactly that many buffers.
420     if (options_.buffers_to_assign != nullptr) {
421       CHECK_EQ(options_.buffers_to_assign->size(), total_chunk_count);
422     }
423   }
424 
425   // Fragmentation is the difference between the actual and ideal sizes.
426   const Result<HloValue> no_frag_result = no_fragmentation_stats_->Finish();
427   result.fragmentation_size = result.heap_size - no_frag_result.heap_size;
428 
429   // Copy the debug trace we collected to the final result.
430   result.debug_trace.Swap(&debug_trace_);
431 
432   return result;
433 }
434 
FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,const HloValue * buffer,const HloInstruction * instruction,const HloValue * share_with_canonical)435 void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
436                                    const HloValue* buffer,
437                                    const HloInstruction* instruction,
438                                    const HloValue* share_with_canonical) {
439   HeapSimulatorTrace::Event* event = debug_trace_.add_events();
440   event->set_kind(kind);
441   event->set_buffer_id(buffer->id());
442   event->set_computation_name(instruction->parent()->name());
443   event->set_instruction_name(instruction->name());
444   if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
445     CHECK(share_with_canonical != nullptr);
446     event->set_share_with_canonical_id(share_with_canonical->id());
447   } else {
448     CHECK(share_with_canonical == nullptr);
449   }
450 }
451 
452 template <typename BufferType>
Alloc(const BufferType * buffer,int64 size)453 void NoFragmentationStatsHeap<BufferType>::Alloc(const BufferType* buffer,
454                                                  int64 size) {
455   current_heap_size_ += size;
456   if (current_heap_size_ > max_heap_size_) {
457     max_heap_size_ = current_heap_size_;
458   }
459 }
460 
461 template <typename BufferType>
AccountForSubcomputationMemory(const HloInstruction * instruction,int64 alloc_size_by_instruction,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)462 void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
463     const HloInstruction* instruction, int64 alloc_size_by_instruction,
464     const absl::flat_hash_map<const HloComputation*, int64>&
465         memory_by_computation) {
466   // We only count the memory usage of the largest subcomputation, instead of
467   // adding them all, because subcomputations won't execute in parallel.
468   int64 max_subcomputation_bytes = 0;
469   for (const auto* c : instruction->called_computations()) {
470     auto it = memory_by_computation.find(c);
471     if (it != memory_by_computation.end()) {
472       int64 subcomputation_bytes = it->second;
473       if (subcomputation_bytes > max_subcomputation_bytes) {
474         max_subcomputation_bytes = subcomputation_bytes;
475       }
476     }
477   }
478   if (max_subcomputation_bytes > 0 &&
479       (instruction->opcode() == HloOpcode::kWhile ||
480        instruction->opcode() == HloOpcode::kCall ||
481        instruction->opcode() == HloOpcode::kConditional)) {
482     // The output buffer of while/call/conditional is always aliased with the
483     // output buffer of the root instruction in the body. Don't double count.
484     max_subcomputation_bytes -= alloc_size_by_instruction;
485   }
486   max_heap_size_ =
487       std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
488 }
489 
490 template <typename BufferType>
Free(const BufferType * buffer,int64 size)491 void NoFragmentationStatsHeap<BufferType>::Free(const BufferType* buffer,
492                                                 int64 size) {
493   current_heap_size_ -= size;
494 }
495 
496 template <typename BufferType>
497 HeapSimulator::Result<BufferType>
Finish()498 NoFragmentationStatsHeap<BufferType>::Finish() {
499   // The result.chunk_map is empty, since we only collect stats, and don't
500   // actually compute chunk assignments.
501   Result result;
502   result.heap_size = max_heap_size_;
503   return result;
504 }
505 
506 template <typename BufferType>
GlobalDecreasingSizeBestFitHeap(int64 alignment,Type type)507 GlobalDecreasingSizeBestFitHeap<BufferType>::GlobalDecreasingSizeBestFitHeap(
508     int64 alignment, Type type)
509     : alignment_(alignment) {
510   if (type == kTemporal) {
511     buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
512   } else {
513     CHECK(type == kSpatial);
514     buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
515   }
516 }
517 
518 template <typename BufferType>
519 typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferIntervalCompare
GetTemporalBufferIntervalCompare() const520 GlobalDecreasingSizeBestFitHeap<BufferType>::GetTemporalBufferIntervalCompare()
521     const {
522   return [&](const BufferInterval& x, const BufferInterval& y) {
523     int64 x_end = x.end;
524     for (auto colocation : GetTransitiveColocations(x)) {
525       x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
526     }
527 
528     int64 y_end = y.end;
529     for (auto colocation : GetTransitiveColocations(y)) {
530       y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
531     }
532 
533     if (x_end - x.start != y_end - y.start) {
534       return x_end - x.start > y_end - y.start;
535     }
536 
537     if (x.size != y.size) {
538       return x.size > y.size;
539     }
540     return *x.buffer < *y.buffer;
541   };
542 }
543 
544 template <typename BufferType>
545 /*static*/ typename GlobalDecreasingSizeBestFitHeap<
546     BufferType>::BufferIntervalCompare
GetSpatialBufferIntervalCompare()547 GlobalDecreasingSizeBestFitHeap<BufferType>::GetSpatialBufferIntervalCompare() {
548   return [&](const BufferInterval& x, const BufferInterval& y) {
549     if (x.size != y.size) {
550       return x.size > y.size;
551     }
552     if (x.end - x.start != y.end - y.start) {
553       return x.end - x.start > y.end - y.start;
554     }
555     return *x.buffer < *y.buffer;
556   };
557 }
558 
559 template <typename BufferType>
Alloc(const BufferType * buffer,int64 size)560 void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
561     const BufferType* buffer, int64 size) {
562   // Degenerate case: 0-sized buffers are always allocated at offset 0.
563   if (size == 0) {
564     result_.chunk_map.emplace(buffer, Chunk{0, 0});
565     return;
566   }
567 
568   auto emplace_result = buffer_intervals_.emplace(
569       buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
570   DCHECK(emplace_result.second);
571   ++current_time_;
572 }
573 
574 template <typename BufferType>
ShareWith(const BufferType * buffer,const BufferType * share_with,int64 size)575 void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
576     const BufferType* buffer, const BufferType* share_with, int64 size) {
577   // Degenerate case: 0-sized buffers are always allocated at offset 0.
578   if (size == 0) {
579     result_.chunk_map.emplace(buffer, Chunk{0, 0});
580     return;
581   }
582   DCHECK_NE(buffer_intervals_.count(share_with), 0);
583   buffer_intervals_[share_with].colocations.push_back(buffer);
584   auto emplace_result = buffer_intervals_.emplace(
585       buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
586   DCHECK(emplace_result.second);
587   ++current_time_;
588 }
589 
590 template <typename BufferType>
591 absl::flat_hash_set<const BufferType*>
GetTransitiveColocations(const BufferInterval & interval) const592 GlobalDecreasingSizeBestFitHeap<BufferType>::GetTransitiveColocations(
593     const BufferInterval& interval) const {
594   absl::flat_hash_set<const BufferType*> result;
595   std::vector<const BufferInterval*> worklist = {&interval};
596   while (!worklist.empty()) {
597     const BufferInterval* item = worklist.back();
598     worklist.pop_back();
599     for (const BufferType* buffer_colocated : item->colocations) {
600       result.insert(buffer_colocated);
601       worklist.push_back(&buffer_intervals_.at(buffer_colocated));
602     }
603   }
604 
605   return result;
606 }
607 
608 template <typename BufferType>
Free(const BufferType * buffer,int64 size)609 void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
610                                                        int64 size) {
611   // Degenerate case: 0-sized buffers are always allocated at offset 0.
612   if (size == 0) {
613     return;
614   }
615   BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
616   DCHECK_EQ(buffer_interval.buffer, buffer);
617   DCHECK_EQ(buffer_interval.size, size);
618   DCHECK_EQ(buffer_interval.end, -1);
619   if (buffer_interval.end != -1) {
620     return;
621   }
622   buffer_interval.end = current_time_;
623   ++current_time_;
624 }
625 
626 using Chunk = HeapSimulator::Chunk;
627 
Add(int64 start,int64 end,const Chunk & chunk)628 void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) {
629   node_storage_.emplace_back(BufferIntervalTreeNode{
630       start, end, end, chunk,
631       /*left=*/nullptr, /*right=*/nullptr, /*parent=*/nullptr});
632   if (root_ == nullptr) {
633     root_ = &node_storage_.back();
634     // This is root.
635     return;
636   }
637 
638   BufferIntervalTreeNode* parent = root_;
639   while (true) {
640     parent->subtree_end = std::max(parent->subtree_end, end);
641     if (parent->start > start) {
642       if (parent->left == nullptr) {
643         parent->left = &node_storage_.back();
644         node_storage_.back().parent = parent;
645         return;
646       }
647       parent = parent->left;
648     } else {
649       if (parent->right == nullptr) {
650         parent->right = &node_storage_.back();
651         node_storage_.back().parent = parent;
652         return;
653       }
654       parent = parent->right;
655     }
656   }
657 }
658 
Remove(int64 start,int64 end,const Chunk & chunk)659 bool BufferIntervalTree::Remove(int64 start, int64 end, const Chunk& chunk) {
660   BufferIntervalTreeNode* to_delete = root_;
661   while (to_delete != nullptr) {
662     if (to_delete->start == start && to_delete->end == end &&
663         to_delete->chunk.offset == chunk.offset) {
664       break;
665     }
666     if (start < to_delete->start) {
667       to_delete = to_delete->left;
668     } else {
669       to_delete = to_delete->right;
670     }
671   }
672   if (to_delete == nullptr) {
673     // Nothing to delete.
674     return false;
675   }
676   // Found the node to be deleted, enter deletion sequence.
677 
678   // Recursively traverse the parents of node and fix up the `subtree_end`
679   // invariant of a node. Recursive lambda need an explicit
680   // std::function declaration.
681   std::function<void(BufferIntervalTreeNode*)> fix_up =
682       [&](BufferIntervalTreeNode* node) {
683         if (node == nullptr) {
684           return;
685         }
686         node->subtree_end = node->end;
687         if (node->left) {
688           node->subtree_end =
689               std::max(node->subtree_end, node->left->subtree_end);
690         }
691         if (node->right) {
692           node->subtree_end =
693               std::max(node->subtree_end, node->right->subtree_end);
694         }
695         // Recursively go up.
696         fix_up(node->parent);
697       };
698 
699   if (to_delete->right == nullptr) {
700     // to_delete has no right child, simply move up left child of to_delete if
701     // any.
702     //
703     // Turn:
704     //      parent
705     //       /
706     // to_delete
707     //  /      \
708     // left    nullptr
709     //
710     // Into:
711     //      parent
712     //      /
713     //    left
714     if (root_ == to_delete) {
715       // Deleting root is simply reseting root;
716       root_ = to_delete->left;
717       return true;
718     }
719 
720     if (to_delete == to_delete->parent->left) {
721       // to_delete is left child of parent.
722       to_delete->parent->left = to_delete->left;
723     }
724     if (to_delete == to_delete->parent->right) {
725       // to_delete is right child of parent.
726       to_delete->parent->right = to_delete->left;
727     }
728     // Rewire parent to the node being moved up.
729     if (to_delete->left) {
730       to_delete->left->parent = to_delete->parent;
731     }
732     // Fix up starting from subroot.
733     fix_up(to_delete);
734   } else {
735     // 1. Find left-most node of the right subtree, promote it to the position
736     // of to_delete.
737     BufferIntervalTreeNode* to_promote = to_delete->right;
738     while (to_promote->left != nullptr) {
739       // Go to left-most subtree.
740       to_promote = to_promote->left;
741     }
742 
743     // 2. Copy the content of `to_promote` to `to_delete`.
744     to_delete->start = to_promote->start;
745     to_delete->end = to_promote->end;
746     // This is incorrect but we will fix this up later in the `fix_up`
747     // procedure.
748     to_delete->subtree_end = to_promote->subtree_end;
749     to_delete->chunk = to_promote->chunk;
750     auto to_promote_parent = to_promote->parent;
751     // 3. Move the right child of `to_promote` up if there is any.
752     //
753     // Turn
754     //
755     // to_delete
756     //         \
757     //        to_promote_parent
758     //         /
759     //    to_promote
760     //          \
761     //          right
762     // into
763     //
764     // to_promote
765     //         \
766     //         to_promote_parent
767     //         /
768     //      right
769     if (to_promote_parent->left == to_promote) {
770       to_promote_parent->left = to_promote->right;
771     } else {
772       to_promote_parent->right = to_promote->right;
773     }
774     if (to_promote->right) {
775       // Set correct parent.
776       to_promote->right->parent = to_promote_parent;
777     }
778     // 4. Recursive fix up the `subtree_end` starting from
779     // `to_promote_parent`.
780     fix_up(to_promote_parent);
781   }
782   // Don't free the entry in node_storage_ until we free the entire tree.
783   return true;
784 }
785 
ChunksOverlappingInTime(int64 start,int64 end) const786 std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
787     int64 start, int64 end) const {
788   std::vector<Chunk> result;
789   if (root_ == nullptr) {
790     return result;
791   }
792   std::vector<const BufferIntervalTreeNode*> visiting_stack;
793   visiting_stack.push_back(root_);
794   while (!visiting_stack.empty()) {
795     const BufferIntervalTreeNode* top = visiting_stack.back();
796     visiting_stack.pop_back();
797     if (start > top->subtree_end) {
798       continue;
799     }
800     if (top->left != nullptr) {
801       visiting_stack.push_back(top->left);
802     }
803     if (top->start <= end && top->end >= start) {
804       result.push_back(top->chunk);
805     }
806     if (end < top->start) {
807       continue;
808     }
809     if (top->right != nullptr) {
810       visiting_stack.push_back(top->right);
811     }
812   }
813   return result;
814 }
815 
816 template <typename BufferType>
817 HeapSimulator::Result<BufferType>
Finish()818 GlobalDecreasingSizeBestFitHeap<BufferType>::Finish() {
819   std::vector<BufferInterval> sorted_buffer_intervals =
820       GetSortedBufferIntervals();
821 
822   for (auto& buffer_interval : sorted_buffer_intervals) {
823     if (!buffer_interval.need_allocation) {
824       continue;
825     }
826 
827     ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
828     // This implementation of the heap algorithm does not have a notion of
829     // maximum heap size, so it just commits.
830     CommitChunk(buffer_interval, chunk_candidate);
831   }
832   VLOG(1) << "result heap_size: " << result_.heap_size;
833   Result result;
834   result.heap_size = result_.heap_size;
835   result.heap_results.emplace_back(result_);
836   return result;
837 }
838 
839 template <typename BufferType>
840 std::vector<
841     typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval>
GetSortedBufferIntervals() const842 GlobalDecreasingSizeBestFitHeap<BufferType>::GetSortedBufferIntervals() const {
843   std::vector<BufferInterval> sorted_buffer_intervals;
844   for (auto& entry : buffer_intervals_) {
845     sorted_buffer_intervals.push_back(entry.second);
846   }
847   absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);
848 
849   return sorted_buffer_intervals;
850 }
851 
852 template <typename BufferType>
853 typename GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
FindChunkCandidate(const GlobalDecreasingSizeBestFitHeap::BufferInterval & buffer_interval,int64 preferred_offset) const854 GlobalDecreasingSizeBestFitHeap<BufferType>::FindChunkCandidate(
855     const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
856     int64 preferred_offset) const {
857   VLOG(1) << "Finding chunks for buffer: "
858           << buffer_interval.buffer->ToString();
859   VLOG(1) << "Size " << buffer_interval.size << ", start "
860           << buffer_interval.start << ", end " << buffer_interval.end;
861   auto chunks_overlapping_in_time = interval_tree_.ChunksOverlappingInTime(
862       buffer_interval.start, buffer_interval.end);
863   // Get all colocated buffers and gather all interferenced chunks.
864   //
865   // Imagine that we've already allocated three chunks : a, b and c.  And now
866   // we want to allocate d. Since e is colocated with d, we have to allocate
867   // chunks for them together at the same address. To do this, we first gather
868   // all chunks that overlap with d and e on the time dimension, in this case
869   // the overlapped chunks are a and b (c doesn't overlap with either of d and
870   // e), then find create a new chunk that doesn't overlap with a and b on the
871   // space dimension.
872   //
873   // space
874   //   ^
875   //   |+--d---+      +---e---+
876   //   |
877   //   |+---+  +---------------+  +-------+
878   //   ||   |  |               |  |       |
879   //   ||   |  |               |  |       |
880   //   |+-a-+  +-------b-------+  +---c---+
881   //   ----------------------------------------> time
882   for (auto colocation : GetTransitiveColocations(buffer_interval)) {
883     auto colocation_interval = buffer_intervals_.at(colocation);
884     auto colocation_overlapping = interval_tree_.ChunksOverlappingInTime(
885         colocation_interval.start, colocation_interval.end);
886     VLOG(1) << "  Alias size " << colocation_interval.size << ", start "
887             << colocation_interval.start << ", end " << colocation_interval.end
888             << " " << colocation_interval.buffer->ToString();
889     chunks_overlapping_in_time.insert(chunks_overlapping_in_time.end(),
890                                       colocation_overlapping.begin(),
891                                       colocation_overlapping.end());
892   }
893   absl::c_sort(chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) {
894     return x.offset < y.offset;
895   });
896 
897   // Find the minimum free chunk that can hold this buffer.
898   ChunkCandidate chunk_candidate{Chunk{-1, INT64_MAX}, result_.heap_size};
899   Chunk& min_fit_chunk = chunk_candidate.chunk;
900   int64 preferred_chunk_end = preferred_offset + buffer_interval.size;
901   auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
902     if (free_size < buffer_interval.size) {
903       return;
904     }
905 
906     // If a preferred offset is provided, pick that offset.
907     if (free_offset <= preferred_offset &&
908         free_offset + free_size >= preferred_chunk_end) {
909       min_fit_chunk = {preferred_offset, buffer_interval.size};
910     } else if (free_offset + free_size == result_.heap_size &&
911                free_offset <= preferred_offset) {
912       // If the free offset is at the very end and if the preferred offset lies
913       // in this, pick the preferred offset and grow the heap.
914       min_fit_chunk = {preferred_offset, buffer_interval.size};
915       chunk_candidate.heap_size = preferred_chunk_end;
916     }
917 
918     // Pick the min-fit chunk only if we didn't have a preferred offset or a
919     // chunk at the preferred offset hasn't been found.
920     if ((preferred_offset < 0 || min_fit_chunk.offset != preferred_offset) &&
921         free_size < min_fit_chunk.size) {
922       min_fit_chunk = {free_offset, free_size};
923     }
924   };
925 
926   int64 offset = 0;
927   for (auto& chunk : chunks_overlapping_in_time) {
928     if (offset < chunk.offset) {
929       use_free_chunk_if_smaller(offset, chunk.offset - offset);
930     }
931     offset = std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
932   }
933   use_free_chunk_if_smaller(offset, result_.heap_size - offset);
934   // When preferred offset is provided and the preferred offset is larger than
935   // the current heap size, simply use the preferred offset provided.
936   if (result_.heap_size <= preferred_offset) {
937     chunk_candidate.heap_size = preferred_chunk_end;
938     min_fit_chunk = {preferred_offset, buffer_interval.size};
939   }
940 
941   if (min_fit_chunk.offset == -1) {
942     // Increase the heap size to fit in the last free chunk.
943     chunk_candidate.heap_size = offset + buffer_interval.size;
944     min_fit_chunk = {offset, buffer_interval.size};
945   }
946 
947   min_fit_chunk.size = buffer_interval.size;
948   return chunk_candidate;
949 }
950 
951 template <typename BufferType>
CommitChunk(const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval & buffer_interval,GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate chunk_candidate)952 void GlobalDecreasingSizeBestFitHeap<BufferType>::CommitChunk(
953     const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval&
954         buffer_interval,
955     GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
956         chunk_candidate) {
957   // Update the maximum heap size according to the one determined by the chunk
958   // candidate.
959   result_.heap_size = chunk_candidate.heap_size;
960   interval_tree_.Add(buffer_interval.start, buffer_interval.end,
961                      chunk_candidate.chunk);
962   for (auto colocation : GetTransitiveColocations(buffer_interval)) {
963     AddToChunkMap(colocation, chunk_candidate.chunk);
964     auto colocation_interval = buffer_intervals_[colocation];
965     interval_tree_.Add(colocation_interval.start, colocation_interval.end,
966                        chunk_candidate.chunk);
967   }
968 
969   AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk);
970 }
971 
972 template <typename BufferType>
AddToChunkMap(const BufferType * buffer,Chunk chunk)973 void GlobalDecreasingSizeBestFitHeap<BufferType>::AddToChunkMap(
974     const BufferType* buffer, Chunk chunk) {
975   const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
976   DCHECK(emplace_result.second);
977 }
978 
979 HeapSimulator::Result<HloValue>
Finish()980 ConstrainedGlobalDecreasingSizeBestFitHeap::Finish() {
981   std::vector<BufferInterval> sorted_buffer_vec = GetSortedBufferIntervals();
982   // Convert into std::list so that erase() is O(1).
983   std::list<BufferInterval> sorted_buffer_intervals(sorted_buffer_vec.begin(),
984                                                     sorted_buffer_vec.end());
985 
986   // Use do-while here, because we need to create 1 heap in `multi_heap_result`
987   // even if `sorted_buffer_intervals` is empty.
988   Result multi_heap_result;
989   do {
990     // Place buffers into the currently processed heap as many as possible.
991     for (auto it = sorted_buffer_intervals.begin();
992          it != sorted_buffer_intervals.end();) {
993       BufferInterval buffer_interval = *it;
994       if (!buffer_interval.need_allocation) {
995         it = sorted_buffer_intervals.erase(it);
996         continue;
997       }
998       if (buffer_interval.size > size_limit_per_heap_) {
999         LOG(WARNING) << "Alloc buffer size " << buffer_interval.size
1000                      << " larger than the per-heap size limit "
1001                      << size_limit_per_heap_;
1002       }
1003 
1004       ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
1005       if (chunk_candidate.heap_size <= size_limit_per_heap_ ||
1006           // Commit the chunk as long as the heap is empty. We do this because
1007           // we want the size constraint to be soft, meaning that results are
1008           // successfully generated even if there are some buffer sizes larger
1009           // than the given constraint size.
1010           result_.heap_size == 0) {
1011         CommitChunk(buffer_interval, chunk_candidate);
1012         it = sorted_buffer_intervals.erase(it);
1013         continue;
1014       }
1015 
1016       ++it;
1017     }
1018     // Collect the result from the currently processed heap and reset the heap
1019     // states.
1020     multi_heap_result.heap_size += result_.heap_size;
1021     multi_heap_result.heap_results.push_back(std::move(result_));
1022     result_ = {};
1023     interval_tree_ = {};
1024   } while (!sorted_buffer_intervals.empty());
1025 
1026   VLOG(1) << "Number of heaps produced = "
1027           << multi_heap_result.heap_results.size();
1028   return multi_heap_result;
1029 }
1030 
1031 template <typename BufferType>
1032 HeapSimulator::Result<BufferType>
Finish()1033 ChooseBestHeapAlgorithm<BufferType>::Finish() {
1034   DCHECK(!algorithms_.empty());
1035   std::vector<Result> results(algorithms_.size());
1036   int64 min_size = INT64_MAX;
1037   int min_size_index = -1;
1038   for (int i = 0; i < algorithms_.size(); ++i) {
1039     results[i] = algorithms_[i]->Finish();
1040     if (results[i].heap_size < min_size) {
1041       min_size = results[i].heap_size;
1042       min_size_index = i;
1043     }
1044   }
1045 
1046   DCHECK_GE(min_size_index, 0);
1047   return results[min_size_index];
1048 }
1049 
// Explicit instantiations. The template member functions above are defined in
// this file, so each buffer type these heap algorithms are used with is
// instantiated here.
template class GlobalDecreasingSizeBestFitHeap<HloValue>;
template class GlobalDecreasingSizeBestFitHeap<
    MemorySpaceAssignmentRepacker::AllocationBlock>;
template class ChooseBestHeapAlgorithm<HloValue>;
1054 
1055 }  // namespace xla
1056