1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Defines the data returned by the XLA buffer assignment packages.
17 
18 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
19 
#include <algorithm>
#include <deque>
#include <numeric>
#include <ostream>
#include <set>
#include <utility>
#include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/buffer_value_containers.h"
33 #include "tensorflow/compiler/xla/service/heap_simulator.h"
34 #include "tensorflow/compiler/xla/service/hlo.pb.h"
35 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
36 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
37 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/hash/hash.h"
45 #include "tensorflow/core/lib/strings/numbers.h"
46 
47 namespace xla {
48 namespace {
49 
50 using absl::flat_hash_map;
51 using absl::flat_hash_set;
52 using absl::StrAppend;
53 using absl::StrAppendFormat;
54 using ::tensorflow::strings::HumanReadableNumBytes;
55 
56 // Given the interference map of a graph (the list of interfering node indices
57 // for each node), perform graph coloring such that interfering nodes are
58 // assigned to different colors. Returns the assigned color of the nodes, where
59 // the colors are represented as integer values [0, color_count).
ColorInterferenceGraph(const std::vector<std::vector<int64>> & interference_map)60 std::vector<int64> ColorInterferenceGraph(
61     const std::vector<std::vector<int64>>& interference_map) {
62   const int64 node_count = interference_map.size();
63 
64   // Sort the nodes such that we assign nodes with more interference first. This
65   // relies on the common heuristic of assigning the most constrained node
66   // first, but it would be good to investigate other ordering heuristics too.
67   std::vector<int64> nodes(node_count);
68   std::iota(nodes.begin(), nodes.end(), 0);
69   absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
70     return interference_map[i].size() > interference_map[j].size();
71   });
72 
73   const int64 kColorUnassigned = -1;
74   std::vector<int64> assigned_colors(node_count, kColorUnassigned);
75   for (int64 node : nodes) {
76     // Mark the colors that are already assigned to the neighbors.
77     std::vector<bool> available_colors(node_count, true);
78     for (int64 neighbor : interference_map[node]) {
79       int64 color = assigned_colors[neighbor];
80       if (color != kColorUnassigned) {
81         available_colors[color] = false;
82       }
83     }
84 
85     // Find the color that is not yet assigned to the neighbors.
86     int64 color = kColorUnassigned;
87     for (color = 0; color < available_colors.size(); ++color) {
88       if (available_colors[color]) {
89         break;
90       }
91     }
92     CHECK_NE(color, kColorUnassigned);
93     assigned_colors[node] = color;
94   }
95   return assigned_colors;
96 }
97 
98 // If an hlo buffer contains an entry parameter, the buffer is read-only unless
99 // it is aliased with an output.
HloBufferIsReadOnly(const HloBuffer & buffer)100 bool HloBufferIsReadOnly(const HloBuffer& buffer) {
101   for (const HloValue* value : buffer.values()) {
102     const HloInstruction* instruction = value->instruction();
103     if (instruction->opcode() == HloOpcode::kConstant) {
104       return true;
105     }
106     const HloModule* module = instruction->parent()->parent();
107     const bool is_entry_parameter =
108         instruction->opcode() == HloOpcode::kParameter &&
109         instruction->parent() == module->entry_computation();
110 
111     if (is_entry_parameter) {
112       bool parameter_has_alias =
113           module->input_output_alias_config().ParameterHasAlias(
114               instruction->parameter_number(), value->index());
115       // The parameter doesn't have an alias, it must be read-only.
116       if (!parameter_has_alias) {
117         return true;
118       }
119     }
120   }
121   return false;
122 }
123 
124 }  // namespace
125 
GatherComputationsByAllocationType(const HloModule * module,std::vector<const HloComputation * > * thread_local_computations,std::vector<const HloComputation * > * global_computations)126 Status GatherComputationsByAllocationType(
127     const HloModule* module,
128     std::vector<const HloComputation*>* thread_local_computations,
129     std::vector<const HloComputation*>* global_computations) {
130   // Create a worklist of computations paired with whether the allocation must
131   // be thread-local.
132   std::deque<std::pair<const HloComputation*, bool>> worklist;
133   worklist.push_back(std::make_pair(module->entry_computation(),
134                                     /*is_thread_local*/ false));
135 
136   // Sets for quickly checking membership. Computations are returned in vectors
137   // for stable iteration.
138   flat_hash_set<const HloComputation*> thread_local_set;
139   flat_hash_set<const HloComputation*> global_set;
140 
141   while (!worklist.empty()) {
142     auto worklist_front = worklist.front();
143     worklist.pop_front();
144     const HloComputation* computation = worklist_front.first;
145     bool is_thread_local = worklist_front.second;
146     bool in_thread_local_set = thread_local_set.contains(computation);
147     bool in_global_set = global_set.contains(computation);
148 
149     // If the computation has already been added to the respective set, then
150     // nothing to do.
151     if ((is_thread_local && in_thread_local_set) ||
152         (!is_thread_local && in_global_set)) {
153       continue;
154     }
155 
156     // If the computation has already been added to the other set this is an
157     // error condition because the global call to the computation (eg,
158     // while/call) may return a reference to one of the thread-local buffers to
159     // the calling computation which will become a dangling reference when the
160     // thread-local is deallocated with the call return.
161     if ((is_thread_local && in_global_set) ||
162         (!is_thread_local && in_thread_local_set)) {
163       return InvalidArgument(
164           "computation %s has conflicting allocation requirements (global "
165           "and thread-local)",
166           computation->name());
167     }
168 
169     if (is_thread_local) {
170       thread_local_set.insert(computation);
171     } else {
172       global_set.insert(computation);
173     }
174 
175     for (auto* instruction : computation->instructions()) {
176       for (HloComputation* subcomputation :
177            instruction->called_computations()) {
178         switch (instruction->opcode()) {
179           case HloOpcode::kCall:
180           case HloOpcode::kConditional:
181           case HloOpcode::kWhile:
182             // Call and while must be called from a computation with global
183             // allocations as they may return references to buffers inside the
184             // called computation which cannot be thread-local.
185             if (is_thread_local) {
186               return InvalidArgument(
187                   "computation %s cannot contain call/while op because it "
188                   "requires thread-local buffer allocations",
189                   computation->name());
190             }
191             worklist.push_back(std::make_pair(subcomputation,
192                                               false));  // Not thread local.
193             break;
194           case HloOpcode::kAllReduce:
195           case HloOpcode::kMap:
196           case HloOpcode::kReduce:
197           case HloOpcode::kReduceWindow:
198           case HloOpcode::kScatter:
199           case HloOpcode::kSelectAndScatter:
200           case HloOpcode::kSort:
201           case HloOpcode::kFusion:
202             // Map/reduce etc computations are always thread-local.
203             worklist.push_back(std::make_pair(subcomputation,
204                                               true));  // Thread local.
205             break;
206           default:
207             return InternalError("Unexpected calling opcode: %s",
208                                  HloOpcodeString(instruction->opcode()));
209         }
210       }
211     }
212   }
213 
214   // Add the computations to the vectors in post order.
215   for (auto* computation : module->MakeComputationPostOrder()) {
216     if (thread_local_set.contains(computation)) {
217       thread_local_computations->push_back(computation);
218     } else if (global_set.contains(computation)) {
219       global_computations->push_back(computation);
220     }
221     // If the computation is not reachable from the entry computation, then it
222     // will not appear in either thread_local_set or global_set. We don't bother
223     // assigning buffers for these.
224   }
225   return Status::OK();
226 }
227 
ToString() const228 string BufferAllocation::Slice::ToString() const {
229   return absl::StrCat("{index:", index(), ", offset:", offset_,
230                       ", size:", size_, "}");
231 }
232 
GetSlice(const HloValue & buffer) const233 BufferAllocation::Slice BufferAllocation::GetSlice(
234     const HloValue& buffer) const {
235   const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
236   return Slice(this, os.offset, os.size);
237 }
238 
AddAssignment(const HloValue & buffer,int64 offset,int64 size)239 void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset,
240                                      int64 size) {
241   VLOG(4) << "Adding the following buffer to allocation #" << index()
242           << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset,
243                              buffer.ToShortString());
244   CHECK(!assigned_buffers_.contains(&buffer))
245       << "LogicalBuffer " << buffer << " already assigned to allocation "
246       << index_;
247   CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
248                           << " offset out of range";
249   CHECK_LE(offset + size, size_)
250       << "LogicalBuffer " << buffer
251       << " size out of range at offset: " << offset << " with size: " << size;
252   CHECK_EQ(buffer.color(), color())
253       << "Buffer color " << buffer.color() << " for buffer " << buffer
254       << " does not match allocation color " << color() << ".";
255   OffsetSize offset_size;
256   offset_size.offset = offset;
257   offset_size.size = size;
258   assigned_buffers_.emplace(&buffer, offset_size);
259   // For debugging purposes, store the assigned memory space in the
260   // instruction's layout.
261   for (HloPosition position : buffer.positions()) {
262     Shape* shape = ShapeUtil::GetMutableSubshape(
263         position.instruction->mutable_shape(), position.index);
264     if (shape->has_layout()) {
265       shape->mutable_layout()->set_memory_space(buffer.color());
266     }
267   }
268 }
269 
ToProto() const270 BufferAllocationProto BufferAllocation::ToProto() const {
271   BufferAllocationProto proto;
272   proto.set_index(index_);
273   proto.set_size(size_);
274   proto.set_is_thread_local(is_thread_local_);
275   proto.set_is_tuple(is_tuple_);
276   proto.set_color(color_);
277   if (is_entry_computation_parameter_) {
278     proto.set_is_entry_computation_parameter(true);
279     for (int64 idx : param_shape_index()) {
280       proto.add_parameter_shape_index(idx);
281     }
282     proto.set_parameter_number(parameter_number_);
283   }
284   proto.set_is_constant(is_constant_);
285   proto.set_maybe_live_out(maybe_live_out_);
286   for (const auto& buffer_offset_size : assigned_buffers_) {
287     BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
288     proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
289     proto_assigned->set_offset(buffer_offset_size.second.offset);
290     proto_assigned->set_size(buffer_offset_size.second.size);
291   }
292   absl::c_sort(*proto.mutable_assigned(),
293                [](const BufferAllocationProto::Assigned& assign1,
294                   const BufferAllocationProto::Assigned& assign2) {
295                  return assign1.logical_buffer_id() <
296                         assign2.logical_buffer_id();
297                });
298   return proto;
299 }
300 
CompareHloValuesById(const HloValue * a,const HloValue * b)301 static bool CompareHloValuesById(const HloValue* a, const HloValue* b) {
302   return a->id() < b->id();
303 }
304 
305 // Returns parameter instruction corresponding to the allocation or nullptr.
GetEntryParameterInstruction(const BufferAllocation & alloc)306 static const HloInstruction* GetEntryParameterInstruction(
307     const BufferAllocation& alloc) {
308   for (const auto& p : alloc.assigned_buffers()) {
309     const HloValue* value = p.first;
310     const HloInstruction* instr = value->instruction();
311     if (instr->opcode() == HloOpcode::kParameter &&
312         instr->parent() == instr->parent()->parent()->entry_computation()) {
313       return instr;
314     }
315   }
316   return nullptr;
317 }
318 
319 // Returns root module output instruction corresponding to the allocation or
320 // nullptr.
GetOutputInstruction(const BufferAllocation & alloc)321 static const HloInstruction* GetOutputInstruction(
322     const BufferAllocation& alloc) {
323   for (const auto& p : alloc.assigned_buffers()) {
324     const HloValue* value = p.first;
325     for (const HloPosition& position : value->positions()) {
326       const HloInstruction* instr = position.instruction;
327       if (position.index.empty() &&
328           instr->parent()->root_instruction() == instr &&
329           instr->parent()->IsEntryComputation()) {
330         return instr;
331       }
332     }
333   }
334   return nullptr;
335 }
336 
ToString() const337 string BufferAllocation::ToString() const {
338   string output;
339   StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
340   if (color() != 0) {
341     StrAppend(&output, ", color ", color());
342   }
343   if (is_entry_computation_parameter()) {
344     const HloInstruction* param = GetEntryParameterInstruction(*this);
345     CHECK(param);
346     StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
347               param->shape().ToString(/*print_layout=*/false),
348               "| at ShapeIndex ", param_shape_index().ToString());
349   }
350   if (const HloInstruction* instr = GetOutputInstruction(*this)) {
351     StrAppend(&output, ", output shape is |",
352               instr->shape().ToString(/*print_layout=*/false), "|");
353   }
354   if (is_constant()) {
355     StrAppend(&output, ", constant");
356   }
357   if (is_thread_local()) {
358     StrAppend(&output, ", thread-local");
359   }
360   if (maybe_live_out()) {
361     StrAppend(&output, ", maybe-live-out");
362   }
363   if (IsPreallocatedTempBuffer()) {
364     StrAppend(&output, ", preallocated-temp");
365   }
366   StrAppend(&output, ":\n");
367   // Dump the assigned buffers ordered by id.
368   std::vector<const HloValue*> sorted_buffers;
369   for (const auto& buffer_offset_size : assigned_buffers_) {
370     sorted_buffers.push_back(buffer_offset_size.first);
371   }
372   absl::c_sort(sorted_buffers, &CompareHloValuesById);
373   for (const HloValue* buffer : sorted_buffers) {
374     const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
375     StrAppend(&output,
376               absl::StrFormat(
377                   " value: %s (size=%d,offset=%d): %s\n",
378                   buffer->ToShortString(), offset_size.size, offset_size.offset,
379                   ShapeUtil::HumanStringWithLayout(buffer->shape())));
380   }
381   return output;
382 }
383 
operator <<(std::ostream & out,const BufferAllocation & buffer)384 std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
385   out << buffer.ToString();
386   return out;
387 }
388 
operator <<(std::ostream & out,const BufferAllocation::Slice & s)389 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
390   out << s.ToString();
391   return out;
392 }
393 
HasAllocation(const HloValue & value) const394 bool BufferAssignment::HasAllocation(const HloValue& value) const {
395   return allocation_index_for_value_.contains(&value);
396 }
397 
HasAllocation(const HloBuffer & buffer) const398 bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const {
399   return allocation_index_for_value_.contains(buffer.values()[0]);
400 }
401 
GetAssignedAllocation(const HloValue & value) const402 const BufferAllocation& BufferAssignment::GetAssignedAllocation(
403     const HloValue& value) const {
404   CHECK(HasAllocation(value));
405   return GetAllocation(allocation_index_for_value_.at(&value));
406 }
407 
GetAssignedAllocation(const HloBuffer & hlo_buffer) const408 const BufferAllocation& BufferAssignment::GetAssignedAllocation(
409     const HloBuffer& hlo_buffer) const {
410   return GetAssignedAllocation(*hlo_buffer.values()[0]);
411 }
412 
GetMutableAssignedAllocation(const HloBuffer & buffer)413 BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
414     const HloBuffer& buffer) {
415   return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
416 }
417 
GetAllSlices(const HloInstruction * instruction,const ShapeIndex & index) const418 std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
419     const HloInstruction* instruction, const ShapeIndex& index) const {
420   std::set<BufferAllocation::Slice> result;
421   for (const HloValue* value :
422        dataflow_analysis().GetValueSet(instruction, index).values()) {
423     if (HasAllocation(*value)) {
424       result.insert(GetAssignedAllocation(*value).GetSlice(*value));
425     }
426   }
427   return result;
428 }
429 
GetAllocation(BufferAllocation::Index index) const430 const BufferAllocation& BufferAssignment::GetAllocation(
431     BufferAllocation::Index index) const {
432   CHECK_GE(index, 0);
433   CHECK_LT(index, allocations_.size());
434   return allocations_[index];
435 }
436 
GetInstructionAllocation(const HloInstruction * hlo,const ShapeIndex & shape_index) const437 const BufferAllocation* BufferAssignment::GetInstructionAllocation(
438     const HloInstruction* hlo, const ShapeIndex& shape_index) const {
439   const HloValue* value =
440       dataflow_analysis().GetValueSet(hlo, shape_index).values()[0];
441 
442   if (!HasAllocation(*value)) {
443     return nullptr;
444   }
445 
446   const BufferAllocation& instruction_allocation =
447       GetAssignedAllocation(*value);
448   return &instruction_allocation;
449 }
450 
GetMutableAllocation(BufferAllocation::Index index)451 BufferAllocation* BufferAssignment::GetMutableAllocation(
452     BufferAllocation::Index index) {
453   return const_cast<BufferAllocation*>(&GetAllocation(index));
454 }
455 
HasAllocationAt(const HloInstruction * instruction,const ShapeIndex & index) const456 bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
457                                        const ShapeIndex& index) const {
458   for (const HloValue* value :
459        dataflow_analysis().GetValueSet(instruction, index).values()) {
460     if (allocation_index_for_value_.contains(value)) {
461       return true;
462     }
463   }
464   return false;
465 }
466 
HasTopLevelAllocation(const HloInstruction * instruction) const467 bool BufferAssignment::HasTopLevelAllocation(
468     const HloInstruction* instruction) const {
469   return HasAllocationAt(instruction, /*index=*/{});
470 }
471 
GetUniqueSlice(const HloInstruction * instruction,const ShapeIndex & index) const472 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
473     const HloInstruction* instruction, const ShapeIndex& index) const {
474   VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
475           << index << "]";
476   BufferAllocation::Slice result;
477   for (const HloValue* value :
478        dataflow_analysis().GetValueSet(instruction, index).values()) {
479     VLOG(3) << "Examining value " << *value;
480     if (HasAllocation(*value)) {
481       VLOG(3) << "Has allocation";
482       const BufferAllocation::Slice slice =
483           GetAssignedAllocation(*value).GetSlice(*value);
484       if (result.allocation() == nullptr) {
485         result = slice;
486       } else if (result != slice) {
487         return FailedPrecondition(
488             "BufferAllocation::Slice for instruction %s at index %s cannot "
489             "be determined at compile-time.",
490             instruction->name(), index.ToString());
491       }
492     } else {
493       VLOG(3) << "No allocation";
494     }
495   }
496   if (result.allocation() == nullptr) {
497     return FailedPrecondition(
498         "BufferAllocation::Slice not assigned for instruction %s at index %s",
499         instruction->name(), index.ToString());
500   }
501   return result;
502 }
503 
GetUniqueTopLevelSlice(const HloInstruction * instruction) const504 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
505     const HloInstruction* instruction) const {
506   return GetUniqueSlice(instruction, /*index=*/{});
507 }
508 
SharesSliceAtIndex(const HloInstruction * hlo_a,const ShapeIndex & shape_index_a,const HloInstruction * hlo_b,const ShapeIndex & shape_index_b) const509 bool BufferAssignment::SharesSliceAtIndex(
510     const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
511     const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
512   return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
513          GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
514 }
515 
HaveDisjointSlices(const HloInstruction * hlo_a,const HloInstruction * hlo_b) const516 bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
517                                           const HloInstruction* hlo_b) const {
518   using SliceSet = flat_hash_set<BufferAllocation::Slice>;
519   // Gets the slices all of instr's subshapes.  If any subshape doesn't have an
520   // assigned slice, returns the empty set.
521   auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
522     SliceSet slices;
523     Status status = ShapeUtil::ForEachSubshapeWithStatus(
524         instr->shape(),
525         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
526           auto shape_slices = GetAllSlices(instr, index);
527           if (shape_slices.empty()) {
528             return InvalidArgument("No slices assigned to part of instr.");
529           }
530           slices.insert(shape_slices.begin(), shape_slices.end());
531           return Status::OK();
532         });
533     if (!status.ok()) {
534       return {};
535     }
536     return slices;
537   };
538 
539   SliceSet slices_a = collect_slices(hlo_a);
540   SliceSet slices_b = collect_slices(hlo_b);
541   // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
542   // didn't return the empty set) for both HLOs, and the two resulting sets of
543   // slices are disjoint.
544   return !slices_a.empty() && !slices_b.empty() &&
545          absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
546            return slices_b.contains(slice);
547          });
548 }
549 
550 StatusOr<BufferAllocation::Slice>
GetUniqueTopLevelOutputSlice() const551 BufferAssignment::GetUniqueTopLevelOutputSlice() const {
552   return GetUniqueTopLevelSlice(
553       module_->entry_computation()->root_instruction());
554 }
555 
NewEmptyAllocation(int64 size,LogicalBuffer::Color color)556 BufferAllocation* BufferAssignment::NewEmptyAllocation(
557     int64 size, LogicalBuffer::Color color) {
558   BufferAllocation::Index index = allocations_.size();
559   allocations_.emplace_back(index, size, color);
560   BufferAllocation* allocation = &allocations_.back();
561   return allocation;
562 }
563 
NewAllocation(const HloBuffer & buffer,int64 size)564 BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
565                                                   int64 size) {
566   BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
567   AddAssignment(allocation, buffer, /*offset=*/0, size);
568   allocation->peak_buffers_.push_back(buffer.values()[0]);
569   return allocation;
570 }
571 
AddAssignment(BufferAllocation * allocation,const HloBuffer & buffer,int64 offset,int64 size)572 void BufferAssignment::AddAssignment(BufferAllocation* allocation,
573                                      const HloBuffer& buffer, int64 offset,
574                                      int64 size) {
575   CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
576       << "Non-reusable allocation already assigned a buffer: "
577       << allocation->ToString();
578 
579   for (const HloValue* buffer_value : buffer.values()) {
580     CHECK(!allocation_index_for_value_.contains(buffer_value))
581         << "BufferValue " << buffer_value << " already has an allocation.";
582     allocation->AddAssignment(*buffer_value, offset, size);
583     allocation_index_for_value_[buffer_value] = allocation->index();
584   }
585 
586   if (alias_analysis().BufferLivesOut(buffer)) {
587     VLOG(3) << "HloBuffer lives out" << buffer.ToString();
588     VLOG(3) << "Set maybe live out: " << allocation->ToString();
589     allocation->set_maybe_live_out(true);
590   }
591 }
592 
AddAssignment(BufferAllocation * allocation,const HloValue & value,int64 offset,int64 size)593 void BufferAssignment::AddAssignment(BufferAllocation* allocation,
594                                      const HloValue& value, int64 offset,
595                                      int64 size) {
596   allocation->AddAssignment(value, offset, size);
597   allocation_index_for_value_[&value] = allocation->index();
598   const HloValue& hlo_value =
599       *CHECK_NOTNULL(dynamic_cast<const HloValue*>(&value));
600   if (alias_analysis().ValueLivesOut(hlo_value)) {
601     VLOG(3) << "HloValue lives out: " << hlo_value.ToString();
602     VLOG(3) << "Set maybe live out: " << allocation->ToString();
603     allocation->set_maybe_live_out(true);
604   }
605 }
606 
607 // Combines allocations of temporary buffers of the same color into one big
608 // BufferAllocation.
CombineTempAllocations()609 void BufferAssignment::CombineTempAllocations() {
610   VLOG(1) << "CombineTempAllocations()";
611   // Stores the combined allocations.
612   std::deque<BufferAllocation> combined_allocations;
613   // Holds the pointer to a combined allocation of each color, if any.
614   flat_hash_map<BufferValue::Color, BufferAllocation*> combined_allocation_map;
615 
616   // Move all temp allocations into a single run at the end of the allocations
617   // vector.
618   const auto first_temp_it =
619       std::partition(allocations_.begin(), allocations_.end(),
620                      [](const BufferAllocation& allocation) {
621                        return !allocation.IsPreallocatedTempBuffer();
622                      });
623 
624   // Walk over the run of temp allocations, collecting the allocations belonging
625   // to the same color.
626   if (first_temp_it != allocations_.end()) {
627     for (auto it = first_temp_it; it != allocations_.end(); ++it) {
628       BufferAllocation& temp_allocation = *it;
629       BufferValue::Color color = temp_allocation.color();
630       auto combined_it = combined_allocation_map.find(color);
631       if (combined_it == combined_allocation_map.end()) {
632         // We have found the first temp allocation of this color. Collect
633         // the other temp allocations of the same color into it subject to the
634         // size constraint.
635         VLOG(1) << "Combined temp allocation for color " << color
636                 << " is: " << temp_allocation;
637         combined_allocations.emplace_back(temp_allocation);
638         combined_allocation_map.emplace(color, &combined_allocations.back());
639         continue;
640       }
641       if (combined_it->second->size() + it->size() >=
642           multiheap_size_constraint_per_heap_) {
643         // We cannot put more into the current combined_it. So, appoint a new
644         // combined_it.
645         VLOG(1) << "Due to size constraint, reset temp allocation for color "
646                 << color << " to: " << temp_allocation;
647         combined_allocations.emplace_back(temp_allocation);
648         combined_allocation_map.emplace(color, &combined_allocations.back());
649         continue;
650       }
651 
652       BufferAllocation* combined_allocation = combined_it->second;
653       VLOG(1) << "Combined allocation absorbing temp allocation: "
654               << temp_allocation;
655 
656       // Each temp allocation is placed end-to-end, accounting for alignment.
657       // The offset of each buffer in the combined allocation is computed from
658       // the base offset of the allocation.
659       int64 alignment = color_alignment_(color);
660       const int64 base =
661           RoundUpToNearest(combined_allocation->size(), alignment);
662       combined_allocation->set_size(base + temp_allocation.size());
663       for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
664         const HloValue* value = buffer_offset_size.first;
665         const int64 offset = buffer_offset_size.second.offset;
666         const int64 size = buffer_offset_size.second.size;
667         combined_allocation->AddAssignment(*value, base + offset, size);
668       }
669       if (!temp_allocation.HeapTraces().empty()) {
670         CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
671         combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
672       }
673 
674       combined_allocation->peak_buffers_.insert(
675           combined_allocation->peak_buffers_.end(),
676           temp_allocation.peak_buffers_.begin(),
677           temp_allocation.peak_buffers_.end());
678     }
679     // Replace all existing temporary allocations with the new combined
680     // allocations.
681     allocations_.erase(first_temp_it, allocations_.end());
682     for (BufferAllocation& combined : combined_allocations) {
683       temp_allocation_total_size_ += combined.size();
684       allocations_.push_back(std::move(combined));
685     }
686   }
687 
688   // Update allocation indices to their new positions.
689   allocation_index_for_value_.erase(allocation_index_for_value_.begin(),
690                                     allocation_index_for_value_.end());
691   for (size_t index = 0; index < allocations_.size(); ++index) {
692     BufferAllocation* allocation = &allocations_[index];
693     allocation->set_index(index);
694     for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
695       const HloValue* value = buffer_offset_size.first;
696       allocation_index_for_value_[value] = index;
697     }
698   }
699 }
700 
ComputeSummaryStats()701 Status BufferAssignment::ComputeSummaryStats() {
702   for (auto& allocation : Allocations()) {
703     if (allocation.is_entry_computation_parameter()) {
704       stats_.parameter_allocation_count++;
705       stats_.parameter_allocation_bytes += allocation.size();
706     }
707     if (allocation.is_constant()) {
708       stats_.constant_allocation_count++;
709       stats_.constant_allocation_bytes += allocation.size();
710     }
711     if (allocation.maybe_live_out()) {
712       stats_.maybe_live_out_allocation_count++;
713       stats_.maybe_live_out_allocation_bytes += allocation.size();
714     }
715     if (allocation.IsPreallocatedTempBuffer()) {
716       stats_.preallocated_temp_allocation_count++;
717       stats_.preallocated_temp_allocation_bytes += allocation.size();
718     }
719     stats_.total_allocation_count++;
720     stats_.total_allocation_bytes += allocation.size();
721   }
722 
723   // Only compute total fragmentation if all computations have schedules.
724   HloSchedule schedule(module_);
725   bool schedule_complete = true;
726   for (const auto& computation : module_->computations()) {
727     if (!computation->IsFusionComputation()) {
728       const HloInstructionSequence* sequence =
729           hlo_ordering().SequentialOrder(*computation);
730       if (sequence == nullptr) {
731         schedule_complete = false;
732       } else {
733         schedule.set_sequence(computation, *sequence);
734       }
735     }
736   }
737   if (schedule_complete) {
738     TF_RETURN_IF_ERROR(schedule.Verify());
739     TF_ASSIGN_OR_RETURN(
740         const int64 min_size,
741         HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
742     stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
743   }
744 
745   return Status::OK();
746 }
747 
ToString() const748 string BufferAssignment::Stats::ToString() const {
749   string s;
750   StrAppendFormat(&s, "BufferAssignment stats:\n");
751   StrAppendFormat(&s, "             parameter allocation: %10s\n",
752                   HumanReadableNumBytes(parameter_allocation_bytes));
753   StrAppendFormat(&s, "              constant allocation: %10s\n",
754                   HumanReadableNumBytes(constant_allocation_bytes));
755   StrAppendFormat(&s, "        maybe_live_out allocation: %10s\n",
756                   HumanReadableNumBytes(maybe_live_out_allocation_bytes));
757   StrAppendFormat(&s, "     preallocated temp allocation: %10s\n",
758                   HumanReadableNumBytes(preallocated_temp_allocation_bytes));
759   if (preallocated_temp_fragmentation_bytes >= 0) {
760     const double percent = 100. * preallocated_temp_fragmentation_bytes /
761                            preallocated_temp_allocation_bytes;
762     StrAppendFormat(
763         &s, "  preallocated temp fragmentation: %10s (%.2f%%)\n",
764         HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
765   }
766   StrAppendFormat(&s, "                 total allocation: %10s\n",
767                   HumanReadableNumBytes(total_allocation_bytes));
768   if (total_fragmentation_bytes >= 0) {
769     const double percent =
770         100. * total_fragmentation_bytes / total_allocation_bytes;
771     StrAppendFormat(&s, "              total fragmentation: %10s (%.2f%%)\n",
772                     HumanReadableNumBytes(total_fragmentation_bytes), percent);
773   }
774   return s;
775 }
776 
ToString() const777 string BufferAssignment::ToString() const {
778   string output;
779   absl::StrAppend(&output, "BufferAssignment:\n");
780   std::vector<const HloValue*> used_values;
781   int64 total_size = 0;
782   for (auto& allocation : allocations_) {
783     total_size += allocation.size();
784     absl::StrAppend(&output, allocation.ToString());
785     for (const auto& p : allocation.assigned_buffers()) {
786       used_values.push_back(p.first);
787     }
788   }
789   absl::StrAppend(&output, "\nTotal bytes used: ", total_size, "\n");
790   absl::StrAppend(&output, "\nUsed values:\n");
791   absl::c_sort(used_values, &CompareHloValuesById);
792   for (const HloValue* value : used_values) {
793     absl::StrAppend(&output, value->ToString());
794   }
795   return output;
796 }
797 
BufferInfoString() const798 string BufferAssignment::BufferInfoString() const {
799   string binfo;
800   // Columns in buffer information:
801   // buffer_id: int. This value can be used to match the allocation in
802   // allocation information.
803   // buffer_name: string.
804   // offset: int. Starting position of the buffer in the memory space.
805   // size: int. Size of the buffer in bytes.
806   // definition_time: int. Position in the schedule where the buffer starts
807   // being live (inclusive).
808   // end_time: int. Position in the schedule where the buffer stops being live
809   // (exclusive).
810   // num_uses: int. Number of uses of the buffer.
811   // use_names: string. This is a semicolon-separated list of string
812   // representation of uses.
813   // Append the column names.
814   absl::StrAppend(&binfo,
815                   "buffer_id,buffer_name,offset,size,"
816                   "definition_time,end_time,num_uses,use_times,use_names\n");
817   const HloLiveRange& live_ranges = hlo_live_range();
818   const auto& instruction_schedule = live_ranges.instruction_schedule();
819   const auto& buffer_live_ranges = live_ranges.buffer_live_ranges();
820   // Sort the buffers by Id.
821   std::vector<std::pair<const HloValue*, BufferAllocation::OffsetSize>> buffers;
822   for (const BufferAllocation& allocation : allocations_) {
823     absl::c_copy(allocation.assigned_buffers(), std::back_inserter(buffers));
824   }
825   absl::c_sort(
826       buffers,
827       [](const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b1,
828          const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b2) {
829         return b1.first->id() < b2.first->id();
830       });
831   for (const auto& buffer_pair : buffers) {
832     const HloValue& buffer = *buffer_pair.first;
833     const BufferAllocation::OffsetSize& offset_size = buffer_pair.second;
834     if (!buffer_live_ranges.contains(&buffer)) {
835       continue;
836     }
837     // Ordering uses by their use position.
838     std::vector<std::pair<int64, std::string>> uses;
839     uses.reserve(buffer.uses().size());
840     for (const HloUse& use : buffer.uses()) {
841       uses.emplace_back(instruction_schedule.at(use.instruction),
842                         use.ToString());
843     }
844     absl::c_sort(uses);
845     std::vector<int64> use_positions;
846     std::vector<std::string> use_names;
847     use_positions.reserve(uses.size());
848     use_names.reserve(uses.size());
849     for (const auto& use : uses) {
850       use_positions.push_back(use.first);
851       use_names.push_back(use.second);
852     }
853     const int64 definition_time =
854         instruction_schedule.at(buffer.defining_position().instruction);
855     const int64 end_t = buffer_live_ranges.at(&buffer).end;
856     absl::StrAppend(&binfo, buffer.id(), ",");
857     absl::StrAppend(&binfo, "\"", buffer.ToShortString(), "\",");
858     absl::StrAppend(&binfo, offset_size.offset, ",");
859     absl::StrAppend(&binfo, offset_size.size, ",");
860     absl::StrAppend(&binfo, definition_time, ",");
861     absl::StrAppend(&binfo, end_t, ",");
862     absl::StrAppend(&binfo, use_positions.size(), ",");
863     absl::StrAppend(&binfo, "\"", absl::StrJoin(use_positions, ";"), "\",");
864     absl::StrAppend(&binfo, "\"", absl::StrJoin(use_names, ";"), "\"");
865     absl::StrAppend(&binfo, "\n");
866   }
867   return binfo;
868 }
869 
ToProto() const870 BufferAssignmentProto BufferAssignment::ToProto() const {
871   BufferAssignmentProto proto;
872   // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
873   // because we need to do the HasAllocation check for each buffer. Otherwise
874   // the buffer_size_ call might fail for some backends.
875   const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
876   for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
877     auto& value = dataflow.values().at(id);
878     if (HasAllocation(*value)) {
879       LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
880       proto.add_logical_buffers()->Swap(&proto_buffer);
881 
882       // Fill buffer aliases.
883       for (const HloValue* alias :
884            alias_analysis().GetBufferContainingValue(*value).values()) {
885         if (alias->instruction() == value->instruction() &&
886             alias->index() == value->index()) {
887           continue;  // skip self-aliases
888         }
889         BufferAssignmentProto::BufferAlias* proto_alias =
890             proto.add_buffer_aliases();
891         LogicalBufferProto::Location proto_alias_location =
892             BufferValue::ToLocationProto(*alias->instruction(), alias->index());
893         proto_alias->set_source_buffer_id(value->id());
894         proto_alias->mutable_location()->Swap(&proto_alias_location);
895       }
896     }
897   }
898   for (const BufferAllocation& allocation : Allocations()) {
899     BufferAllocationProto proto_allocation = allocation.ToProto();
900     proto.add_buffer_allocations()->Swap(&proto_allocation);
901     for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
902       *proto.add_heap_simulator_traces() = heap_trace;
903     }
904   }
905   return proto;
906 }
907 
908 /* static */
Run(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,BufferValue::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment,bool allocate_buffers_for_constants,BufferAssigner::Colorer colorer,const absl::flat_hash_set<HloOpcode> & must_not_live_out,HloDataflowAnalysis::CanShareBuffer can_share_buffer,std::unique_ptr<PresetAssignments> preset_assignments)909 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
910     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
911     BufferValue::SizeFunction buffer_size,
912     LogicalBuffer::AlignmentFunction color_alignment,
913     bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
914     const absl::flat_hash_set<HloOpcode>& must_not_live_out,
915     HloDataflowAnalysis::CanShareBuffer can_share_buffer,
916     std::unique_ptr<PresetAssignments> preset_assignments) {
917   BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
918                           must_not_live_out, std::move(preset_assignments));
919   return assigner.CreateAssignment(
920       module, std::move(hlo_ordering), std::move(buffer_size),
921       std::move(color_alignment), std::move(can_share_buffer));
922 }
923 
LiveRangeInterferes(const HloValue * buffer1,const HloValue * buffer2,BufferAssignment * assignment)924 bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
925                                          const HloValue* buffer2,
926                                          BufferAssignment* assignment) {
927   CHECK((assignment->hlo_live_range().total_order_scheduled()));
928   const HloLiveRange& hlo_live_range = assignment->hlo_live_range();
929 
930   const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();
931 
932   CHECK(buffer_live_ranges.contains(buffer1))
933       << "Buffer doesn't have a proper live range:" << buffer1;
934 
935   CHECK(buffer_live_ranges.contains(buffer2))
936       << "Buffer doesn't have a proper live range:" << buffer2;
937 
938   // Check if a user value can share the same buffer as its operand.
939   auto can_share_as_operand = [&assignment](const HloValue* user_value,
940                                             const HloValue* operand_value) {
941     return user_value->instruction()->IsUserOf(operand_value->instruction()) &&
942            assignment->dataflow_analysis().CanShareOperandBufferWithUser(
943                operand_value->instruction(), operand_value->index(),
944                user_value->instruction(), user_value->index()) &&
945            user_value->instruction()->opcode() != HloOpcode::kCopy;
946   };
947 
948   auto live_range_1 = buffer_live_ranges.at(buffer1);
949   auto live_range_2 = buffer_live_ranges.at(buffer2);
950 
951   if (!(live_range_1.start > live_range_2.end ||
952         live_range_2.start > live_range_1.end)) {
953     if (live_range_1.end == live_range_2.start) {
954       auto operand_value = buffer1;
955       auto user_value = buffer2;
956       if (!can_share_as_operand(user_value, operand_value)) {
957         VLOG(4) << "End of live range of " << buffer1->ToShortString()
958                 << " is equal to the start of live range of "
959                 << buffer2->ToShortString() << ", buffer cannot be shared.";
960         return true;
961       }
962     } else if (live_range_2.end == live_range_1.start) {
963       auto operand_value = buffer2;
964       auto user_value = buffer1;
965       if (!can_share_as_operand(user_value, operand_value)) {
966         VLOG(4) << "End of live range of " << buffer2->ToShortString()
967                 << " is equal to the start of live range of "
968                 << buffer1->ToShortString() << ", buffer cannot be shared.";
969         return true;
970       }
971     } else {
972       VLOG(4) << "Can't assign: assignee " << *buffer1 << " may interfere with "
973               << *buffer2;
974       VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
975       VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
976       VLOG(4) << "live_range_2.start" << live_range_2.start;
977       VLOG(4) << "live_range_2.end" << live_range_2.end;
978       return true;
979     }
980   }
981   return false;
982 }
983 
MaybeAssignBuffer(BufferAllocation * allocation,const HloBuffer & hlo_buffer,BufferAssignment * assignment)984 bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
985                                        const HloBuffer& hlo_buffer,
986                                        BufferAssignment* assignment) {
987   CHECK(!assignment->HasAllocation(hlo_buffer))
988       << "buffer " << hlo_buffer << " already has an allocation assigned.";
989 
990   VLOG(4) << "Trying to assign " << hlo_buffer << " size "
991           << assignment->HloBufferSize(hlo_buffer)
992           << " to allocation: " << *allocation;
993 
994   if (hlo_buffer.color() != allocation->color()) {
995     VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
996             << " and allocation has color " << allocation->color() << ".";
997     return false;
998   }
999 
1000   if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
1001     VLOG(4) << "Can't assign: buffer is larger than allocation ("
1002             << assignment->HloBufferSize(hlo_buffer) << " > "
1003             << allocation->size() << ")";
1004     return false;
1005   }
1006 
1007   if (allocation->is_readonly()) {
1008     VLOG(4) << "Can't assign: allocation is readonly";
1009     return false;
1010   }
1011 
1012   if (!must_not_live_out_.empty()) {
1013     if (allocation->maybe_live_out()) {
1014       // If a buffer maybe live out, the allocation cannot contain any node from
1015       // the "must_not_live_out_" set.
1016       for (const HloValue* value : hlo_buffer.values()) {
1017         if (must_not_live_out_.count(value->instruction()->opcode()) > 0) {
1018           VLOG(4) << "Can't assign: " << value->instruction()->ToString()
1019                   << " cannot live out of the module";
1020           return false;
1021         }
1022       }
1023     }
1024     // The above check is not enough -- There could be the case where an
1025     // allocation can be not live out and contains an instruction with opcode
1026     // from the "must_not_live_out_" set, but assigning a live out buffer to
1027     // that allocation makes the allocation live out and also contains
1028     // instruction from the "must_not_live_out_" set.
1029     if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) {
1030       for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
1031         if (must_not_live_out_.count(
1032                 buffer_offset_size.first->instruction()->opcode()) > 0) {
1033           VLOG(4) << "Can't assign: " << buffer_offset_size.first->instruction()
1034                   << " cannot live out of the module";
1035           return false;
1036         }
1037       }
1038     }
1039   }
1040 
1041   if (!allocation->is_reusable()) {
1042     VLOG(4) << "Can't assign: allocation is not reusable";
1043     return false;
1044   }
1045 
1046   for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
1047     // Pairwise compare.
1048     const HloValue& assigned_buffer =
1049         *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
1050     for (const HloValue* new_value : hlo_buffer.values()) {
1051       if (assignment->hlo_live_range().total_order_scheduled()) {
1052         if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
1053           VLOG(4) << "Can't assign: assignee " << assigned_buffer
1054                   << " live range interferes with "
1055                   << new_value->ToShortString();
1056           return false;
1057         }
1058       } else if (assignment->hlo_ordering().MayInterfere(
1059                      assigned_buffer, *new_value,
1060                      assignment->dataflow_analysis())) {
1061         // Fallback to partial order based interference detection (slower) when
1062         // we don't have a total order scheduled module.
1063         VLOG(4) << "Can't assign: assignee " << assigned_buffer
1064                 << " may interfere with " << new_value->ToShortString();
1065         return false;
1066       }
1067 
1068       for (const HloPosition& assigned_buffer_position :
1069            assigned_buffer.positions()) {
1070         // Copy instruction don't share a buffer with their input operand.
1071         if (new_value->instruction()->IsUserOf(
1072                 assigned_buffer_position.instruction) &&
1073             new_value->instruction()->opcode() == HloOpcode::kCopy) {
1074           VLOG(4) << "Can't assign: assignee " << assigned_buffer
1075                   << " is used at copy instruction "
1076                   << new_value->ToShortString();
1077           return false;
1078         }
1079       }
1080     }
1081   }
1082 
1083   // If the buffer is live out of the computation then it should only be
1084   // assigned a buffer which exactly fits the result to avoid wasting memory
1085   // (result buffers can have arbitrary lifetimes).
1086   if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) &&
1087       allocation->size() != assignment->HloBufferSize(hlo_buffer)) {
1088     VLOG(4) << "Can't assign: buffer " << hlo_buffer
1089             << "is live out and size not the same as allocation";
1090     return false;
1091   }
1092 
1093   assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0,
1094                             assignment->HloBufferSize(hlo_buffer));
1095   return true;
1096 }  // namespace xla
1097 
AssignSingleHloBuffer(const HloBuffer * hlo_buffer,bool is_thread_local,absl::flat_hash_map<const HloComputation *,absl::flat_hash_set<const HloValue * >> * buffers_to_assign_sequentially,std::vector<BufferAllocation::Index> * allocation_indices,BufferAssignment * assignment)1098 Status BufferAssigner::AssignSingleHloBuffer(
1099     const HloBuffer* hlo_buffer, bool is_thread_local,
1100     absl::flat_hash_map<const HloComputation*,
1101                         absl::flat_hash_set<const HloValue*>>*
1102         buffers_to_assign_sequentially,
1103     std::vector<BufferAllocation::Index>* allocation_indices,
1104     BufferAssignment* assignment) {
1105   const int64 buffer_size = assignment->HloBufferSize(*hlo_buffer);
1106   for (const HloValue* value : hlo_buffer->values()) {
1107     if (value->instruction()->opcode() == HloOpcode::kConstant) {
1108       if (allocate_buffers_for_constants_) {
1109         BufferAllocation* allocation =
1110             assignment->NewAllocation(*hlo_buffer, buffer_size);
1111         allocation->set_constant(true);
1112         VLOG(3) << "New allocation #" << allocation->index() << " for constant "
1113                 << *hlo_buffer << " value ptr: " << value;
1114       }
1115       VLOG(3) << "Not allocating buffer for constant";
1116       return Status::OK();
1117     }
1118 
1119     const HloInstruction* instruction = value->instruction();
1120     const bool is_entry_parameter =
1121         instruction->opcode() == HloOpcode::kParameter &&
1122         instruction->parent() ==
1123             instruction->parent()->parent()->entry_computation();
1124 
1125     if (is_entry_parameter) {
1126       bool parameter_has_alias =
1127           assignment->module().input_output_alias_config().ParameterHasAlias(
1128               instruction->parameter_number(), value->index());
1129       // If the hlo buffer is part of an external parameter, creates a new
1130       // allocation and sets its parameter number. Parameters of non-entry
1131       // computations do not need special allocations because they live inside
1132       // callers.
1133       BufferAllocation* allocation =
1134           assignment->NewAllocation(*hlo_buffer, buffer_size);
1135 
1136       allocation->set_entry_computation_parameter(
1137           instruction->parameter_number(), value->index(), parameter_has_alias);
1138       if (parameter_has_alias) {
1139         allocation_indices->push_back(allocation->index());
1140       }
1141       VLOG(3) << "New allocation #" << allocation->index()
1142               << " marked as entry computation parameter: " << *hlo_buffer;
1143       return Status::OK();
1144     }
1145   }
1146 
1147   if (is_thread_local) {
1148     BufferAllocation* allocation =
1149         assignment->NewAllocation(*hlo_buffer, buffer_size);
1150     allocation->set_is_thread_local(true);
1151     VLOG(3) << "New allocation #" << allocation->index()
1152             << " for thread-local: " << *hlo_buffer;
1153     return Status::OK();
1154   }
1155 
1156   for (const HloValue* value : hlo_buffer->values()) {
1157     if (value->shape().IsTuple()) {
1158       BufferAllocation* allocation =
1159           assignment->NewAllocation(*hlo_buffer, buffer_size);
1160       allocation->set_is_tuple(true);
1161       VLOG(3) << "New allocation #" << allocation->index()
1162               << " for tuple-shaped buffer: " << *hlo_buffer;
1163       return Status::OK();
1164     }
1165 
1166     if (value->IsTopLevel() && !value->IsTuple()) {
1167       const HloInstruction* instruction = value->instruction();
1168       for (auto* operand : instruction->operands()) {
1169         for (const auto& operand_slice :
1170              assignment->GetAllSlices(operand, /*index=*/{})) {
1171           BufferAllocation* allocation =
1172               assignment->GetMutableAllocation(operand_slice.index());
1173           if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
1174             VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
1175                     << " for: " << *hlo_buffer;
1176             return Status::OK();
1177           }
1178         }
1179       }
1180     }
1181   }
1182 
1183   // Find the smallest buffer which can be reused iterating from end of
1184   // allocation_indices (smallest) to beginning (largest).
1185   for (int allocation_index = allocation_indices->size() - 1;
1186        allocation_index >= 0; allocation_index--) {
1187     BufferAllocation* allocation = assignment->GetMutableAllocation(
1188         allocation_indices->at(allocation_index));
1189     if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
1190       VLOG(3) << "Reusing allocation #" << allocation->index()
1191               << " for: " << *hlo_buffer;
1192       return Status::OK();
1193     }
1194   }
1195 
1196   if (!assignment->HasAllocation(*hlo_buffer) &&
1197       !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) {
1198     bool all_computations_have_sequential_order = true;
1199     for (const HloValue* hlo_value : hlo_buffer->values()) {
1200       HloComputation* computation = hlo_value->instruction()->parent();
1201       const bool has_sequential_order =
1202           assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
1203       all_computations_have_sequential_order &= has_sequential_order;
1204     }
1205 
1206     if (all_computations_have_sequential_order) {
1207       for (const HloValue* hlo_value : hlo_buffer->values()) {
1208         HloComputation* computation = hlo_value->instruction()->parent();
1209         // There is a sequential instruction ordering, so we delay assignment
1210         // of temp buffers until after the loop. We do this right before we
1211         // decide to create a new allocation, to ensure we've exhausted all
1212         // the buffer re-use cases above.
1213         //
1214         // Entry parameters and thread local buffers were already handled
1215         // earlier in this loop iteration.  See
1216         // BufferAllocation::IsPreallocatedTempBuffer for the definition of
1217         // temp buffers.
1218         (*buffers_to_assign_sequentially)[computation].insert(hlo_value);
1219         VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_value;
1220       }
1221       return Status::OK();
1222     }
1223   }
1224 
1225   if (!assignment->HasAllocation(*hlo_buffer)) {
1226     BufferAllocation* allocation =
1227         assignment->NewAllocation(*hlo_buffer, buffer_size);
1228     allocation_indices->push_back(allocation->index());
1229     VLOG(3) << "New allocation #" << allocation->index()
1230             << " for: " << *hlo_buffer;
1231   }
1232 
1233   TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer));
1234   return Status::OK();
1235 }
1236 
AssignBuffersForComputations(const std::vector<const HloComputation * > & computations,bool is_thread_local,absl::flat_hash_map<const HloComputation *,absl::flat_hash_set<const HloValue * >> * buffers_to_assign_sequentially,BufferAssignment * assignment)1237 Status BufferAssigner::AssignBuffersForComputations(
1238     const std::vector<const HloComputation*>& computations,
1239     bool is_thread_local,
1240     absl::flat_hash_map<const HloComputation*,
1241                         absl::flat_hash_set<const HloValue*>>*
1242         buffers_to_assign_sequentially,
1243     BufferAssignment* assignment) {
1244   if (computations.empty()) {
1245     return Status::OK();
1246   }
1247   std::vector<const HloBuffer*> sorted_buffers;
1248 
1249   // First assign the preset allocations.
1250   absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;
1251 
1252   TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));
1253 
1254   const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
1255 
1256   for (const HloBuffer& buffer : alias_analysis.buffers()) {
1257     // Skip if the buffer is already assigned since it had a preset allocation.
1258     if (preset_assigned_buffers.find(&buffer) !=
1259         preset_assigned_buffers.end()) {
1260       VLOG(3) << "Skip allocation for buffer: " << buffer;
1261       continue;
1262     }
1263     TF_RET_CHECK(!buffer.values().empty());
1264     const HloComputation* comp = buffer.values()[0]->instruction()->parent();
1265     if (absl::c_linear_search(computations, comp)) {
1266       sorted_buffers.push_back(&buffer);
1267     }
1268   }
1269 
1270   // Generate a post order sort of instructions for sorting of the
1271   // HloBuffers.
1272   flat_hash_map<const HloInstruction*, int> post_order_position;
1273   int position = 0;
1274   std::vector<const HloComputation*> reverse_post_order_computations;
1275   std::unique_ptr<CallGraph> call_graph =
1276       CallGraph::Build(computations[0]->parent());
1277   TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) {
1278     if (absl::c_linear_search(computations, node.computation())) {
1279       reverse_post_order_computations.push_back(node.computation());
1280     }
1281     return Status::OK();
1282   }));
1283   absl::c_reverse(reverse_post_order_computations);
1284   for (auto* computation : reverse_post_order_computations) {
1285     for (auto* instruction : computation->MakeInstructionPostOrder()) {
1286       post_order_position.emplace(instruction, position);
1287       position++;
1288     }
1289   }
1290 
1291   HloSchedule schedule(&assignment->module());
1292 
1293   for (const HloComputation* computation : computations) {
1294     const HloInstructionSequence* instruction_sequence =
1295         assignment->hlo_ordering().SequentialOrder(*computation);
1296     const bool has_sequential_order = instruction_sequence != nullptr;
1297     if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
1298       // Every sequential computation must get an entry in the
1299       // buffers_to_assign_sequentially map, even if we end up with an empty
1300       // set of buffers. This ensures we can correctly determine whether to
1301       // run whole-module heap simulation.
1302       buffers_to_assign_sequentially->emplace(computation,
1303                                               flat_hash_set<const HloValue*>());
1304 
1305       schedule.set_sequence(computation, *instruction_sequence);
1306     }
1307   }
1308 
1309   absl::c_sort(
1310       sorted_buffers, [&post_order_position, &alias_analysis, assignment](
1311                           const HloBuffer* a, const HloBuffer* b) {
1312         // Primary sort is by decreasing buffer size.
1313         const int64 a_size = assignment->HloBufferSize(*a);
1314         const int64 b_size = assignment->HloBufferSize(*b);
1315         if (a_size != b_size) {
1316           return a_size > b_size;  // use ">" for decreasing size.
1317         }
1318 
1319         const bool a_live_out = alias_analysis.BufferLivesOut(*a);
1320         const bool b_live_out = alias_analysis.BufferLivesOut(*b);
1321         if (a_live_out != b_live_out) {
1322           return a_live_out;
1323         }
1324         auto compare = [&post_order_position](const HloValue* value1,
1325                                               const HloValue* value2) {
1326           return post_order_position.at(value1->instruction()) <
1327                  post_order_position.at(value2->instruction());
1328         };
1329         const HloValue* a_min = *absl::c_min_element(a->values(), compare);
1330         const HloValue* b_min = *absl::c_min_element(b->values(), compare);
1331         return compare(a_min, b_min);
1332       });
1333 
1334   std::vector<BufferAllocation::Index> allocation_indices;
1335 
1336   for (const HloBuffer* buffer : sorted_buffers) {
1337     VLOG(3) << "=================================================";
1338     VLOG(3) << "Assigning buffer for " << *buffer;
1339     TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local,
1340                                              buffers_to_assign_sequentially,
1341                                              &allocation_indices, assignment));
1342   }
1343   return Status::OK();
1344 }
1345 
1346 flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
SplitBuffersByColor(const flat_hash_set<const HloValue * > & buffers)1347 BufferAssigner::SplitBuffersByColor(
1348     const flat_hash_set<const HloValue*>& buffers) {
1349   flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
1350   for (auto buffer : buffers) {
1351     color_map[buffer->color()].insert(buffer);
1352   }
1353   return color_map;
1354 }
1355 
AssignPresetBuffers(absl::flat_hash_set<const HloBuffer * > * assigned_buffers,BufferAssignment * assignment)1356 Status BufferAssigner::AssignPresetBuffers(
1357     absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
1358     BufferAssignment* assignment) {
1359   if (!preset_assignments_) {
1360     return Status::OK();
1361   }
1362 
1363   // Create an allocation for each preset color.
1364   absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
1365       preset_allocations;
1366   for (auto& color_and_info : preset_assignments_->assignment_informations()) {
1367     LogicalBuffer::Color color(color_and_info.first);
1368     auto inserted = preset_allocations.emplace(
1369         color,
1370         assignment->NewEmptyAllocation(color_and_info.second.size, color));
1371     BufferAllocation* inserted_allocation = inserted.first->second;
1372     inserted_allocation->AddHeapTrace(
1373         color_and_info.second.heap_simulator_trace);
1374     VLOG(3) << "Created preset buffer allocation "
1375             << inserted_allocation->index()
1376             << ", color: " << inserted_allocation->color()
1377             << ", size: " << inserted_allocation->size();
1378   }
1379 
1380   const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
1381 
1382   for (auto& position_and_chunk : preset_assignments_->chunks()) {
1383     const HloPosition& defining_position = position_and_chunk.first;
1384     const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
1385         defining_position.instruction, defining_position.index);
1386     for (const HloValue* value : buffer.values()) {
1387       VLOG(3) << "Preset allocation for value: " << value->ToShortString();
1388       const HeapSimulator::Chunk& chunk = position_and_chunk.second;
1389       auto preset_allocations_iter = preset_allocations.find(value->color());
1390       CHECK(preset_allocations_iter != preset_allocations.end())
1391           << "No preset value allocation for color " << value->color()
1392           << " for " << value->ToShortString() << " found.";
1393       preset_allocations_iter->second->AddAssignment(*value, chunk.offset,
1394                                                      chunk.size);
1395     }
1396 
1397     assigned_buffers->insert(&buffer);
1398   }
1399 
1400   // Upon consumption of the preset assignments, delete it so that if this
1401   // method is called again, it does not assign the same buffers multiple times.
1402   preset_assignments_ = {};
1403 
1404   return Status::OK();
1405 }
1406 
AssignBuffersWithSequentialOrdering(const flat_hash_map<const HloComputation *,flat_hash_set<const HloValue * >> & buffers_to_assign_sequentially,bool run_whole_module_heap_simulation,BufferAssignment * assignment)1407 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
1408     const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
1409         buffers_to_assign_sequentially,
1410     bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
1411   // Run the sequence of instructions through the heap simulator.  The
1412   // heuristic that seems to give the best results is lazy-best-fit, with all
1413   // runs of alloc / free calls sorted in decreasing size order.
1414   const HloOrdering& hlo_ordering = assignment->hlo_ordering();
1415 
1416   // Returns a heap algorithm that chooses the best result from several
1417   // algorithms.
1418   auto get_heap_algorithm = [&](int64 alignment) {
1419     auto algorithms = absl::make_unique<
1420         std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
1421     algorithms->push_back(
1422         absl::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
1423             assignment->multiheap_size_constraint_per_heap(), alignment,
1424             GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
1425     algorithms->push_back(
1426         absl::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
1427             assignment->multiheap_size_constraint_per_heap(), alignment,
1428             GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
1429     return absl::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
1430         std::move(algorithms));
1431   };
1432 
1433   if (run_whole_module_heap_simulation) {
1434     // Run the heap simulation over the whole module. This reduces memory
1435     // usage, since buffers for kCall, kWhile, and kConditional
1436     // sub-computations are only live for the duration of their calling
1437     // instructions.
1438     VLOG(1) << "Running whole-module heap simulation";
1439     HloSchedule schedule(&assignment->module());
1440     flat_hash_set<const HloValue*> all_buffers_to_assign;
1441     for (const auto& pair : buffers_to_assign_sequentially) {
1442       const HloComputation* computation = pair.first;
1443       const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
1444       const HloInstructionSequence* instruction_sequence =
1445           hlo_ordering.SequentialOrder(*computation);
1446       CHECK(instruction_sequence != nullptr) << computation->name();
1447       schedule.set_sequence(computation, *instruction_sequence);
1448       all_buffers_to_assign.insert(buffers_to_assign.begin(),
1449                                    buffers_to_assign.end());
1450     }
1451     auto color_map = SplitBuffersByColor(all_buffers_to_assign);
1452     for (auto& single_colored_set : color_map) {
1453       auto color = single_colored_set.first;
1454       VLOG(2) << "Simulating heap for color " << color;
1455       int64 alignment = assignment->color_alignment_(color);
1456       HeapSimulator::Options options;
1457       options.alloc_constants = allocate_buffers_for_constants_;
1458       options.buffers_to_assign = &single_colored_set.second;
1459 
1460       TF_ASSIGN_OR_RETURN(
1461           HeapSimulator::Result<HloValue> result,
1462           HeapSimulator::Run(
1463               get_heap_algorithm(alignment), assignment->module(), schedule,
1464               assignment->alias_analysis(), assignment->buffer_size_, options));
1465       AssignBuffersFromHeapSimulator(result, assignment,
1466                                      single_colored_set.first);
1467     }
1468   } else {
1469     // Run the heap-simulation on a per-computation basis. Buffers for
1470     // sub-computations are assigned disjoint BufferAllocations, assuming the
1471     // worst-case that they may all be live concurrently.
1472     VLOG(1) << "Running per-computation heap simulation";
1473     for (const auto& pair : buffers_to_assign_sequentially) {
1474       const HloComputation* computation = pair.first;
1475       const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
1476       const HloInstructionSequence* instruction_sequence =
1477           hlo_ordering.SequentialOrder(*computation);
1478       CHECK(instruction_sequence != nullptr) << computation->name();
1479       auto color_map = SplitBuffersByColor(buffers_to_assign);
1480       for (auto& single_colored_set : color_map) {
1481         auto color = single_colored_set.first;
1482         VLOG(2) << "Simulating heap for color " << color;
1483         int64 alignment = assignment->color_alignment_(color);
1484         HeapSimulator::Options options;
1485         options.buffers_to_assign = &single_colored_set.second;
1486         TF_ASSIGN_OR_RETURN(
1487             HeapSimulator::Result<HloValue> result,
1488             HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
1489                                *instruction_sequence,
1490                                assignment->alias_analysis(),
1491                                assignment->buffer_size_, options));
1492         AssignBuffersFromHeapSimulator(result, assignment,
1493                                        single_colored_set.first);
1494       }
1495     }
1496   }
1497   return Status::OK();
1498 }
1499 
1500 namespace {
1501 // Computes and returns the set of logical buffers live at the point of
1502 // maximal liveness in the given heap trace. LogicalBuffers are (stabily)
1503 // sorted by id.
ComputePeakMemoryLogicalBuffers(const BufferAllocation & allocation,const HeapSimulatorTrace & heap_trace)1504 std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
1505     const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
1506   // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
1507   // buffers in this allocation.
1508   absl::flat_hash_map<BufferValue::Id, const HloValue*> id_to_value;
1509   absl::flat_hash_map<const HloValue*, int64> buffer_sizes;
1510   for (const auto& pair : allocation.assigned_buffers()) {
1511     const HloValue* value = pair.first;
1512     const BufferAllocation::OffsetSize& offset_size = pair.second;
1513     id_to_value[value->id()] = value;
1514     buffer_sizes[value] = offset_size.size;
1515   }
1516   VLOG(1) << "Compute peak memory logical buffers";
1517 
1518   // Returns how much the given event increases the total size of live
1519   // buffers. Can be negative.
1520   auto memory_delta = [&id_to_value, &buffer_sizes](
1521                           const HeapSimulatorTrace::Event& event) -> int64 {
1522     const HloValue* buffer = id_to_value.at(event.buffer_id());
1523     const int64 buffer_size = buffer_sizes.at(buffer);
1524     if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
1525         event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
1526       return buffer_size;
1527     } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
1528       return -1 * buffer_size;
1529     }
1530     LOG(FATAL) << "Unknown event kind: " << event.kind();
1531   };
1532 
1533   // First compute the size of the maximal live set.
1534   int64 max_live_size = 0;
1535   int64 live_size = 0;
1536   for (const auto& event : heap_trace.events()) {
1537     if (!id_to_value.contains(event.buffer_id())) {
1538       // Skip as the buffer associated with this trace event is not placed into
1539       // this allocation. This can happen when size constraints are given to the
1540       // heap simulator.
1541       continue;
1542     }
1543     live_size += memory_delta(event);
1544     if (max_live_size < live_size) {
1545       max_live_size = live_size;
1546     }
1547   }
1548 
1549   // Next gather the set of logical buffers live at the earliest point of
1550   // maximal live set size.
1551   absl::flat_hash_set<const HloValue*> live_values;
1552   live_size = 0;
1553   for (const auto& event : heap_trace.events()) {
1554     if (!id_to_value.contains(event.buffer_id())) {
1555       // Skip as the buffer associated with this trace event is not placed into
1556       // this allocation. This can happen when size constraints are given to the
1557       // heap simulator.
1558       continue;
1559     }
1560     const HloValue* value = id_to_value.at(event.buffer_id());
1561     if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
1562         event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
1563       InsertOrDie(&live_values, value);
1564     } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
1565       CHECK(ContainsKey(live_values, value));
1566       live_values.erase(value);
1567     }
1568     live_size += memory_delta(event);
1569 
1570     if (live_size == max_live_size) {
1571       break;
1572     }
1573   }
1574   CHECK_EQ(live_size, max_live_size);
1575 
1576   std::vector<const HloValue*> live_values_vector;
1577   live_values_vector.insert(live_values_vector.end(), live_values.begin(),
1578                             live_values.end());
1579 
1580   // Stabily sort the live buffers.
1581   absl::c_sort(live_values_vector, [](const HloValue* a, const HloValue* b) {
1582     return a->id() < b->id();
1583   });
1584   VLOG(4) << "Peak memory buffer:";
1585   for (auto value : live_values_vector) {
1586     VLOG(4) << "  " << value->ToString();
1587   }
1588   return live_values_vector;
1589 }
1590 
1591 }  // namespace
1592 
AssignBuffersFromHeapSimulator(const HeapSimulator::Result<HloValue> & result,BufferAssignment * assignment,BufferValue::Color color)1593 void BufferAssigner::AssignBuffersFromHeapSimulator(
1594     const HeapSimulator::Result<HloValue>& result, BufferAssignment* assignment,
1595     BufferValue::Color color) {
1596   if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
1597     assignment->stats_.preallocated_temp_fragmentation_bytes =
1598         result.fragmentation_size;
1599   } else {
1600     assignment->stats_.preallocated_temp_fragmentation_bytes +=
1601         result.fragmentation_size;
1602   }
1603   VLOG(1) << "Result size from heap simulator: " << result.heap_size;
1604 
1605   // Iterate through heap_results. For each heap_result, create a new allocation
1606   // in `assignment`.
1607   for (const HeapSimulator::HeapResult<HloValue>& heap_result :
1608        result.heap_results) {
1609     BufferAllocation* allocation =
1610         assignment->NewEmptyAllocation(heap_result.heap_size, color);
1611     for (const auto& buffer_chunk : heap_result.chunk_map) {
1612       const HloValue& value = *buffer_chunk.first;
1613       const HeapSimulator::Chunk& chunk = buffer_chunk.second;
1614       assignment->AddAssignment(allocation, value, chunk.offset, chunk.size);
1615     }
1616     allocation->peak_buffers_ =
1617         ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);
1618 
1619     XLA_VLOG_LINES(2, allocation->ToString());
1620 
1621     allocation->AddHeapTrace(result.debug_trace);
1622   }
1623 }
1624 
CreateAssignment(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,BufferValue::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment,HloDataflowAnalysis::CanShareBuffer can_share_buffer)1625 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
1626     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
1627     BufferValue::SizeFunction buffer_size,
1628     LogicalBuffer::AlignmentFunction color_alignment,
1629     HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
1630   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1631                       HloAliasAnalysis::Run(module, can_share_buffer));
1632 
1633   // Set up a schedule for each computation.
1634   HloSchedule schedule(module);
1635   for (const HloComputation* computation : module->computations()) {
1636     const HloInstructionSequence* instruction_sequence =
1637         hlo_ordering->SequentialOrder(*computation);
1638     const bool has_sequential_order = instruction_sequence != nullptr;
1639     if (has_sequential_order) {
1640       schedule.set_sequence(computation, *instruction_sequence);
1641     }
1642   }
1643 
1644   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
1645                       HloLiveRange::Run(schedule, *alias_analysis,
1646                                         module->entry_computation(), true));
1647 
1648   VLOG(1) << "Assigning buffers to module " << module->name();
1649   XLA_VLOG_LINES(3, module->ToString());
1650   XLA_VLOG_LINES(3, alias_analysis->ToString());
1651   XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString());
1652   VLOG(1) << "Number of buffers to assign: "
1653           << alias_analysis->buffers().size();
1654 
1655   // Can't use absl::make_unique because BufferAssignment constructor is
1656   // private.
1657   std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
1658       module, std::move(hlo_ordering), std::move(buffer_size),
1659       std::move(color_alignment), std::move(alias_analysis),
1660       std::move(hlo_live_range)));
1661 
1662   TF_RETURN_IF_ERROR(
1663       colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
1664   VLOG(3) << "After coloring:";
1665   XLA_VLOG_LINES(3,
1666                  assignment->alias_analysis().dataflow_analysis().ToString());
1667 
1668   std::vector<const HloComputation*> thread_local_computations;
1669   std::vector<const HloComputation*> global_computations;
1670   TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
1671       module, &thread_local_computations, &global_computations));
1672 
1673   // First assign buffers for global computations. Temporary buffers for
1674   // sequential computations are collected in
1675   // 'buffers_to_assign_sequentially'.
1676   flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>
1677       buffers_to_assign_sequentially;
1678   TF_RETURN_IF_ERROR(AssignBuffersForComputations(
1679       global_computations,
1680       /*is_thread_local=*/false, &buffers_to_assign_sequentially,
1681       assignment.get()));
1682   // Assign buffers with sequential ordering, if any. If all global
1683   // computations are sequential, we can run heap simulation on the whole
1684   // module, which reduces memory usage.
1685   const bool run_whole_module_heap_simulation =
1686       buffers_to_assign_sequentially.size() == global_computations.size();
1687   VLOG(2) << "Running whole module heap simulation: "
1688           << run_whole_module_heap_simulation;
1689   const int32 multiheap_size_constraint_per_heap =
1690       module->config().debug_options().xla_multiheap_size_constraint_per_heap();
1691   VLOG(2) << "Multiheap per heap size limit: "
1692           << multiheap_size_constraint_per_heap;
1693   TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
1694       buffers_to_assign_sequentially, run_whole_module_heap_simulation,
1695       assignment.get()));
1696 
1697   std::vector<const HloComputation*> thread_local_computations_no_fusion;
1698   // Now assign buffers for thread-local computations. All LogicalBuffers get
1699   // their own BufferAllocation.
1700 
1701   for (auto* computation : thread_local_computations) {
1702     TF_RET_CHECK(computation != module->entry_computation());
1703     if (computation->IsFusionComputation()) {
1704       continue;
1705     }
1706     thread_local_computations_no_fusion.push_back(computation);
1707   }
1708 
1709   TF_RETURN_IF_ERROR(AssignBuffersForComputations(
1710       thread_local_computations_no_fusion,
1711       /*is_thread_local=*/true,
1712       /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));
1713 
1714   // Mark all buffers which may be live out of the entry computation as
1715   // "liveout".
1716   for (const HloBuffer* buffer :
1717        assignment->alias_analysis().LiveOutBuffers()) {
1718     VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
1719     if (assignment->HasAllocation(*buffer)) {
1720       BufferAllocation* alloc =
1721           assignment->GetMutableAssignedAllocation(*buffer);
1722       alloc->set_maybe_live_out(true);
1723       VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
1724     }
1725   }
1726 
1727   // Combines allocations of temporary buffers into big BufferAllocations
1728   // subject to the buffer allocation size constraint. This can only be
1729   // performed after all buffers have been assigned, and after maybe_live_out
1730   // is marked, since it is used to determine whether an allocation contains
1731   // temporary buffers or not.
1732   assignment->CombineTempAllocations();
1733 
1734   XLA_VLOG_LINES(2, assignment->ToString());
1735   TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
1736   XLA_VLOG_LINES(1, assignment->GetStats().ToString());
1737   VLOG(1) << "Buffer assignment done.";
1738   return std::move(assignment);
1739 }
1740 
1741 }  // namespace xla
1742