1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
18 
#include <functional>
#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
41 
42 namespace xla {
43 
44 // Walk the call graph of the HLO module and place each computation into either
45 // thread_local_computations or global_computations depending upon whether the
46 // computation requires thread-local allocations or global allocations. The
47 // elements in thread_local_computations and global_computations are in post
48 // order (if computation A has an instruction which calls computation B, then A
49 // will appear after B in the vector).
50 Status GatherComputationsByAllocationType(
51     const HloModule* module,
52     std::vector<const HloComputation*>* thread_local_computations,
53     std::vector<const HloComputation*>* global_computations);
54 
55 // This class abstracts an allocation of contiguous memory which can hold the
56 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
57 // of the allocation, represented by a Slice. A single BufferAllocation may hold
58 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A
59 // single BufferAllocation may also hold LogicalBuffers with overlapping
60 // liveness, which must have disjoint Slices.
61 //
62 // The abstraction includes information required by the backends for allocation,
63 // use, and deallocation of the buffer. This includes the LogicalBuffers which
64 // are held in this allocation through the execution of the computation.
65 class BufferAllocation {
66  public:
67   // Holds a unique identifier for each allocation. Values are assigned
68   // contiguously and can be used as array indexes.
69   using Index = int64;
70 
BufferAllocation(Index index,int64 size,LogicalBuffer::Color color)71   BufferAllocation(Index index, int64 size, LogicalBuffer::Color color)
72       : index_(index), size_(size), color_(color) {}
~BufferAllocation()73   ~BufferAllocation() {}
74 
75   // Returns the index of this allocation.
index()76   Index index() const { return index_; }
77 
78   // Whether this allocation is used in a parallel calling context such as
79   // inside of a map or reduce computation. Such allocations need to be thread
80   // local.
is_thread_local()81   bool is_thread_local() const { return is_thread_local_; }
set_is_thread_local(bool is_thread_local)82   void set_is_thread_local(bool is_thread_local) {
83     is_thread_local_ = is_thread_local;
84   }
85 
86   // Whether this allocation can be used by more than one logical buffer.
is_reusable()87   bool is_reusable() const {
88     // We do not reuse thread-local buffers for now, because they are
89     // dynamically allocated and their lifetimes are hard to compute.
90     //
91     // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
92     // assumes longer buffer liveness than indicated by the analysis.
93     return !is_thread_local() && !is_tuple();
94   }
95 
96   // Whether this allocation is readonly i.e. backed by memory we cannot write
97   // to.
is_readonly()98   bool is_readonly() const {
99     // Entry parameters are generally readonly, except when they are aliased
100     // with any output.
101     return (is_entry_computation_parameter() &&
102             !is_parameter_aliased_with_output_) ||
103            is_constant();
104   }
105 
is_tuple()106   bool is_tuple() const { return is_tuple_; }
set_is_tuple(bool is_tuple)107   void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; }
108 
109   // Whether this allocation holds a LogicalBuffer from a parameter of the entry
110   // computation. These buffers have lifetimes which may be longer than the
111   // XLA computation.
is_entry_computation_parameter()112   bool is_entry_computation_parameter() const {
113     return is_entry_computation_parameter_;
114   }
115 
116   // Whether this allocation holds a constant.  On the CPU and GPU backends
117   // constant allocations are not allocated dynamically, instead we resolve
118   // references to these buffer allocations to a global in the readonly section
119   // of the binary.
is_constant()120   bool is_constant() const { return is_constant_; }
121 
122   // If this allocation holds a Buffer from a parameter of the entry
123   // computation, this methods returns the parameter number. CHECKs otherwise.
parameter_number()124   int64 parameter_number() const {
125     CHECK(is_entry_computation_parameter_);
126     return parameter_number_;
127   }
128 
129   // If this allocation is for a parameter of the entry computation, this
130   // function returns which subshape of the parameter the allocation is for.
param_shape_index()131   const ShapeIndex& param_shape_index() const {
132     CHECK(is_entry_computation_parameter_);
133     return param_shape_index_;
134   }
135 
136   // Returns whether this allocation is assigned a LogicalBuffer which may
137   // be live out of the entry computation.
maybe_live_out()138   bool maybe_live_out() const { return maybe_live_out_; }
139 
140   // Returns the size of the allocation. Necessarily this must be at least as
141   // large as any LogicalBuffer assigned to this allocation.
size()142   int64 size() const { return size_; }
143 
144   // Returns the color of the allocation. Only logical buffers with a matching
145   // color can reside in this allocation.
color()146   LogicalBuffer::Color color() const { return color_; }
147 
148   struct OffsetSize {
149     int64 offset = 0;
150     int64 size = 0;
151   };
152 
153   // Access to the logical buffers assigned to this allocation, and their
154   // associated logical offsets and sizes.
155   const absl::flat_hash_map<const LogicalBuffer*, OffsetSize>&
assigned_buffers()156   assigned_buffers() const {
157     return assigned_buffers_;
158   }
159 
160   // A Slice represents a contiguous portion of a memory allocation. It is used
161   // to identify the memory range that a LogicalBuffer corresponds to.
162   class Slice {
163    public:
Slice()164     Slice() {}
Slice(const BufferAllocation * allocation,int64 offset,int64 size)165     Slice(const BufferAllocation* allocation, int64 offset, int64 size)
166         : allocation_(allocation), offset_(offset), size_(size) {}
167 
allocation()168     const BufferAllocation* allocation() const { return allocation_; }
index()169     Index index() const { return allocation_->index(); }
offset()170     int64 offset() const { return offset_; }
size()171     int64 size() const { return size_; }
172 
173     bool operator==(const Slice& other) const {
174       return index() == other.index() && offset_ == other.offset_ &&
175              size_ == other.size_;
176     }
177     bool operator!=(const Slice& other) const { return !(*this == other); }
178     bool operator<(const Slice& other) const {
179       if (index() != other.index()) return index() < other.index();
180       if (offset_ != other.offset_) return offset_ < other.offset_;
181       return size_ < other.size_;
182     }
183 
184     // Returns true iff this slice's memory range has a non-empty intersection
185     // with the other slice's memory range.
OverlapsWith(const Slice & other)186     bool OverlapsWith(const Slice& other) const {
187       const int64 end = offset_ + size_;
188       const int64 other_end = other.offset_ + other.size_;
189       return index() == other.index() && offset_ < other_end &&
190              end > other.offset_;
191     }
192 
193     template <typename H>
AbslHashValue(H h,const Slice & s)194     friend H AbslHashValue(H h, const Slice& s) {
195       return H::combine(std::move(h), s.index(), s.offset(), s.size());
196     }
197 
198     string ToString() const;
199 
200    private:
201     const BufferAllocation* allocation_ = nullptr;
202     int64 offset_ = 0;
203     int64 size_ = 0;
204   };
205 
206   // GetSlice returns the Slice of contiguous memory that holds the value
207   // described by the given 'buffer'.
208   // REQUIRES: 'buffer' must be assigned to this allocation.
209   Slice GetSlice(const LogicalBuffer& buffer) const;
210 
211   string ToString() const;
212   BufferAllocationProto ToProto() const;
213 
214   // Whether the buffer is a parameter to or live out of the entry computation.
IsInputOrOutput()215   bool IsInputOrOutput() const {
216     return is_entry_computation_parameter() || maybe_live_out();
217   }
218 
219   // Whether the buffer is a temporary buffer allocated before
220   // Executable::ExecuteOnStream.
IsPreallocatedTempBuffer()221   bool IsPreallocatedTempBuffer() const {
222     // Parameters do not need temporary buffers.
223     return !is_entry_computation_parameter() &&
224            // LogicalBuffers that maybe pointed to by the output should live out
225            // of the computation.
226            !maybe_live_out() &&
227            // Thread-local buffers are allocated using `alloca`s.
228            !is_thread_local() &&
229            // Constant buffers are allocated as global values.
230            !is_constant();
231   }
232 
233   // Add a heap trace which was used to assign slices to logical buffers in this
234   // allocation. A single BufferAllocation may include multiple heap traces
235   // in the case of the temporary block where there is a heap trace per
236   // computation.
AddHeapTrace(const HeapSimulatorTrace & heap_trace)237   void AddHeapTrace(const HeapSimulatorTrace& heap_trace) {
238     heap_traces_.push_back(heap_trace);
239   }
240 
241   // Return the set of heap traces used to assign slices to logical buffers in
242   // this allocation.
HeapTraces()243   const std::vector<HeapSimulatorTrace> HeapTraces() const {
244     return heap_traces_;
245   }
246 
247   // Returns the LogicalBuffers which are live at the point of peak memory usage
248   // for this allocation. The point of peak memory usage is the point at which
249   // the total size of all live logical buffers is maximal. If peak memory is
250   // reached at multiple points, the set of logical buffers live at the earliest
251   // maximal point is returned. The vector is stabily sorted by
252   // LogicalBuffer::Index.
PeakMemoryLogicalBuffers()253   const std::vector<const LogicalBuffer*>& PeakMemoryLogicalBuffers() const {
254     return peak_buffers_;
255   }
256 
257   // Get the number of bytes lost to fragmentation. This is equal to the
258   // difference between the size of the allocation and the size of the maximal
259   // live set.
fragmentation_bytes()260   int64 fragmentation_bytes() const { return fragmentation_bytes_; }
261 
262   bool operator==(const BufferAllocation& other) const {
263     return index_ == other.index_;
264   }
265   bool operator!=(const BufferAllocation& other) const {
266     return !(*this == other);
267   }
268   bool operator<(const BufferAllocation& other) const {
269     return index() < other.index();
270   }
271 
272  private:
273   // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
274   friend class BufferAssigner;
275   friend class BufferAssignment;
276 
277   // Adds a LogicalBuffer to the set assigned to this buffer.
278   void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size);
279 
set_entry_computation_parameter(int64 parameter_number,ShapeIndex param_shape_index,bool parameter_aliased_with_output)280   void set_entry_computation_parameter(int64 parameter_number,
281                                        ShapeIndex param_shape_index,
282                                        bool parameter_aliased_with_output) {
283     is_entry_computation_parameter_ = true;
284     is_parameter_aliased_with_output_ = parameter_aliased_with_output;
285     parameter_number_ = parameter_number;
286     param_shape_index_ = std::move(param_shape_index);
287   }
288 
set_constant(bool is_constant)289   void set_constant(bool is_constant) { is_constant_ = is_constant; }
set_maybe_live_out(bool value)290   void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
set_index(Index index)291   void set_index(Index index) { index_ = index; }
set_size(int64 size)292   void set_size(int64 size) { size_ = size; }
293 
294   // The index of the allocation in the BufferAssignment.
295   Index index_;
296 
297   // Size of the allocation in bytes.
298   int64 size_;
299 
300   // Whether this buffer needs to be thread-local.
301   bool is_thread_local_ = false;
302 
303   // Whether this buffer holds a tuple.
304   bool is_tuple_ = false;
305 
306   // Color of the allocation.
307   LogicalBuffer::Color color_;
308 
309   // Whether this allocation holds an entry computation parameter. Entry
310   // computation parameters are special be cause they have lifetimes which may
311   // outlast the computation.
312   bool is_entry_computation_parameter_ = false;
313 
314   // Whether this entry computation parameter is aliased with output.
315   bool is_parameter_aliased_with_output_ = false;
316 
317   // If this allocation holds an entry computation parameter, this field
318   // indicates the index (starting from 0) of the parameter.
319   int64 parameter_number_ = 0;
320 
321   // If this buffer is for an entry computation parameter, which subshape of the
322   // parameter is it for?
323   ShapeIndex param_shape_index_;
324 
325   // Whether the allocation contains a LogicalBuffer which may be live-out of
326   // the entry computation. Note that this flag is conservatively computed by
327   // TuplePointsToAnalysis.  That is, an allocation marked `maybe_live_out_`
328   // might not actually escape.
329   bool maybe_live_out_ = false;
330 
331   // See comment on the is_constant() accessor.
332   bool is_constant_ = false;
333 
334   // Mapping from the set of buffers assigned to this allocation to their
335   // logical offsets and sizes.
336   absl::flat_hash_map<const LogicalBuffer*, OffsetSize> assigned_buffers_;
337 
338   int64 fragmentation_bytes_ = 0;
339   std::vector<HeapSimulatorTrace> heap_traces_;
340 
341   // Set of buffers live at the point of peak memory usage for this allocation.
342   std::vector<const LogicalBuffer*> peak_buffers_;
343 };
344 
345 // Add stream operators for nicer output of CHECK/RET_CHECK failures.
346 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
347 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s);
348 
349 // This class encapsulates an assignment of the LogicalBuffers in an XLA
350 // module to a set of BufferAllocations.
351 class BufferAssignment {
352  public:
353   // Returns the vector containing all buffer allocations in this assignment.
Allocations()354   const std::vector<BufferAllocation>& Allocations() const {
355     return allocations_;
356   }
357 
358   // Returns the total size allocation holding all temporary buffers.
temp_allocation_total_size()359   int64 temp_allocation_total_size() const {
360     return temp_allocation_total_size_;
361   }
362 
363   // Returns whether the given buffer has been assigned an allocation.
364   bool HasAllocation(const LogicalBuffer& buffer) const;
365 
366   // Returns the allocation that a particular LogicalBuffer has been assigned
367   // to. CHECKs if buffer has not been assigned an allocation.
368   const BufferAllocation& GetAssignedAllocation(
369       const LogicalBuffer& buffer) const;
370 
371   // Returns the allocation with the given index. CHECKs if no allocation exists
372   // with the given index.
373   const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
374 
375   // Returns the allocation with the given instruction and shape index. nullptr
376   // if no allocation exists.
377   const BufferAllocation* GetInstructionAllocation(
378       const HloInstruction* hlo, const ShapeIndex& shape_index) const;
379 
380   // Builds and returns a vector containing the slices which might contain the
381   // subvalue at the given index of given instruction.
382   std::set<BufferAllocation::Slice> GetAllSlices(
383       const HloInstruction* instruction, const ShapeIndex& index) const;
384 
385   // Convenience function which returns whether the buffer of the
386   // instruction at the given index is assigned an allocation.
387   bool HasAllocationAt(const HloInstruction* instruction,
388                        const ShapeIndex& index) const;
389 
390   // Convenience function which returns whether the top-level buffer of the
391   // instruction (index == {}) is assigned an allocation.
392   bool HasTopLevelAllocation(const HloInstruction* instruction) const;
393 
394   // Convenience function which returns the unique slice containing the buffer
395   // at the given index of the given instruction. If a slice is not assigned or
396   // the slice cannot be determined at compile time then an error is returned.
397   StatusOr<BufferAllocation::Slice> GetUniqueSlice(
398       const HloInstruction* instruction, const ShapeIndex& index) const;
399   // Like GetUniqueSlice but fixes the index to the top-level of the shape
400   // (index = {}).
401   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice(
402       const HloInstruction* instruction) const;
403   // Like GetUniqueTopLevelSlice but returns the slice for the output of the
404   // entry computation of the HLO module (ie, the result of the XLA
405   // computation).
406   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const;
407 
408   // Returns the set LogicalBuffers which may be the source of the value at the
409   // given index and instruction.
GetSourceBuffers(const HloInstruction * instruction,const ShapeIndex & index)410   const PointsToSet::BufferList& GetSourceBuffers(
411       const HloInstruction* instruction, const ShapeIndex& index) const {
412     return GetPointsToSet(instruction).element(index);
413   }
414 
415   // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
416   // share the same BufferAllocation::Slice.
417   // Returns false otherwise.
418   // REQUIRES: BufferAssignment assigned allocations to both instructions.
419   bool SharesSliceAtIndex(const HloInstruction* hlo_a,
420                           const ShapeIndex& shape_index_a,
421                           const HloInstruction* hlo_b,
422                           const ShapeIndex& shape_index_b) const;
423 
424   // Returns true if the top-level buffers of hlo_a and hlo_b are the same.
425   // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b).
SharesTopLevelSlice(const HloInstruction * hlo_a,const HloInstruction * hlo_b)426   bool SharesTopLevelSlice(const HloInstruction* hlo_a,
427                            const HloInstruction* hlo_b) const {
428     return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
429   }
430 
431   // Returns true if hlo_a and hlo_b both have at least one buffer assigned for
432   // their top-level and each of their nested shape indices, and if hlo_a's
433   // buffers are all different from hlo_b's buffers.
434   bool HaveDisjointSlices(const HloInstruction* hlo_a,
435                           const HloInstruction* hlo_b) const;
436 
437   // Returns the underlying points-to analysis used for this assignment.
points_to_analysis()438   const TuplePointsToAnalysis& points_to_analysis() const {
439     return liveness_->points_to_analysis();
440   }
441 
442   // Returns the BufferLiveness object used to construct this assignment.
liveness()443   const BufferLiveness& liveness() const { return *liveness_; }
444 
445   string ToString() const;
446   BufferAssignmentProto ToProto() const;
447 
448   // Statistics for the assignment.  Values initialized to -1 are not always
449   // collected; fragmentation is only collected for instructions that have a
450   // sequential total ordering.
451   struct Stats {
452     int64 parameter_allocation_count = 0;
453     int64 parameter_allocation_bytes = 0;
454     int64 constant_allocation_count = 0;
455     int64 constant_allocation_bytes = 0;
456     int64 maybe_live_out_allocation_count = 0;
457     int64 maybe_live_out_allocation_bytes = 0;
458     int64 preallocated_temp_allocation_count = 0;
459     int64 preallocated_temp_allocation_bytes = 0;
460     int64 preallocated_temp_fragmentation_bytes = -1;
461     int64 total_allocation_count = 0;
462     int64 total_allocation_bytes = 0;
463     int64 total_fragmentation_bytes = -1;
464 
465     string ToString() const;
466   };
GetStats()467   const Stats& GetStats() const { return stats_; }
468 
469  private:
470   // Only BufferAssigner can build or modify BufferAssignments.
471   friend class BufferAssigner;
472 
BufferAssignment(const HloModule * module,std::unique_ptr<BufferLiveness> liveness,LogicalBuffer::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment)473   BufferAssignment(const HloModule* module,
474                    std::unique_ptr<BufferLiveness> liveness,
475                    LogicalBuffer::SizeFunction buffer_size,
476                    LogicalBuffer::AlignmentFunction color_alignment)
477       : module_(module),
478         liveness_(std::move(liveness)),
479         buffer_size_(std::move(buffer_size)),
480         color_alignment_(std::move(color_alignment)) {}
481 
482   // Creates and returns a new BufferAllocation, with no assigned
483   // LogicalBuffers. Ownership is maintained internally.
484   BufferAllocation* NewEmptyAllocation(int64 size, LogicalBuffer::Color color);
485 
486   // Helper that calls NewEmptyAllocation and AddAssignment in one call,
487   // creating an allocation containing a single LogicalBuffer.
488   BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size);
489 
490   // Adds a LogicalBuffer to the set assigned to the given allocation.
491   void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer,
492                      int64 offset, int64 size);
493 
494   // Returns the HloModule used to construct this assignment.
module()495   const HloModule& module() const { return *module_; }
496 
497   // Convenience function which returns the PointsToSet for the given
498   // instruction. Extracted from the liveness object.
499   const PointsToSet& GetPointsToSet(const HloInstruction* instruction) const;
500 
501   // Mutable accessors for allocations.
502   BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer);
503   BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
504 
505   // Combines allocations of temporary buffers into one big BufferAllocation.
506   void CombineTempAllocations();
507 
508   // Computes stats for the assignment, to be retrieved by GetStats.
509   Status ComputeSummaryStats();
510 
511   // The vector of buffer allocations. Indexed by BufferAllocation::Index.
512   std::vector<BufferAllocation> allocations_;
513 
514   // The total size of all temporary buffers.
515   int64 temp_allocation_total_size_ = 0;
516 
517   // Maps Buffers to the index of the BufferAllocation which holds the buffer.
518   absl::flat_hash_map<const LogicalBuffer*, BufferAllocation::Index>
519       allocation_index_for_buffer_;
520 
521   const HloModule* module_;
522   const std::unique_ptr<BufferLiveness> liveness_;
523 
524   // Function which returns the buffer size for a given logical buffer (shape).
525   LogicalBuffer::SizeFunction buffer_size_;
526 
527   // Function which returns the alignment for a given logical buffer color.
528   LogicalBuffer::AlignmentFunction color_alignment_;
529 
530   Stats stats_;
531 
532   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
533 };
534 
535 // A class which constructs a buffer assignment.
536 class BufferAssigner {
537  public:
538   // Returns false if a buffer cannot be assigned to given allocation.
539   using ReuseAllocationFunction = std::function<bool(
540       const BufferAssignment& assignment, const BufferAllocation& alloc,
541       const LogicalBuffer& buffer)>;
542 
543   // Build and return a BufferAssignment for the given module. The given
544   // HloOrdering is used to determine buffer liveness. buffer_size and
545   // color_alignment are functions which returns the size and alignment of a
546   // LogicalBuffer.  allow_input_output_aliasing specifies whether input buffer
547   // are allowed to be reused as outbut buffers by the client code.
548   static StatusOr<std::unique_ptr<BufferAssignment>> Run(
549       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
550       LogicalBuffer::SizeFunction buffer_size,
551       LogicalBuffer::AlignmentFunction color_alignment,
552       bool allow_input_output_aliasing = false,
553       bool allocate_buffers_for_constants = false,
554       BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer(),
555       ReuseAllocationFunction reuse_checker = nullptr);
556 
557  private:
BufferAssigner(bool allocate_buffers_for_constants,BufferLiveness::Colorer colorer,ReuseAllocationFunction reuse_checker)558   BufferAssigner(bool allocate_buffers_for_constants,
559                  BufferLiveness::Colorer colorer,
560                  ReuseAllocationFunction reuse_checker)
561       : allocate_buffers_for_constants_(allocate_buffers_for_constants),
562         colorer_(colorer),
563         reuse_checker_(reuse_checker) {}
564   virtual ~BufferAssigner() = default;
565 
566   // Create a buffer assignment.
567   StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
568       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
569       LogicalBuffer::SizeFunction buffer_size,
570       LogicalBuffer::AlignmentFunction color_alignment);
571 
572   // Assigns buffers to the instructions in the given computation. "assignment"
573   // is modified to reflect the new buffer assignments. If is_thread_local is
574   // true, then all assigned buffers have the is_thread_local flag set to
575   // true.
576   Status AssignBuffersForComputation(
577       const HloComputation* computation, bool is_thread_local,
578       const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers,
579       const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations,
580       absl::flat_hash_map<const HloComputation*,
581                           absl::flat_hash_set<const LogicalBuffer*>>*
582           buffers_to_assign_sequentially,
583       BufferAssignment* assignment);
584 
585   // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
586   // the HLO instructions will be executed in the sequential order given by
587   // assignment->liveness().hlo_ordering().SequentialOrder. If
588   // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
589   // assuming all global computations are sequentially ordered.
590   Status AssignBuffersWithSequentialOrdering(
591       const absl::flat_hash_map<const HloComputation*,
592                                 absl::flat_hash_set<const LogicalBuffer*>>&
593           buffers_to_assign_sequentially,
594       bool run_whole_module_heap_simulation, BufferAssignment* assignment);
595 
596   // Uses the results of the heap simulator to create a single allocation, with
597   // LogicalBuffers packed to specific offsets.
598   void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
599                                       BufferAssignment* assignment,
600                                       LogicalBuffer::Color color);
601 
602   // Tries to assign the given instruction to the given buffer. Returns if the
603   // assignment was successful.
604   bool MaybeAssignBuffer(BufferAllocation* allocation,
605                          const LogicalBuffer& buffer,
606                          BufferAssignment* assignment);
607 
608   // Colocated buffers are logical buffers from different computations which
609   // alias. Explicitly handling these colocated buffers is necessary because
610   // points-to analysis is computation level scope and does not recognize
611   // aliasing across computations (b/32491382).
612   using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>;
613 
614   // Returns a vector of ColocatedBufferSet objects, where each
615   // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
616   // which should be colocated in the same buffer allocation.
617   void BuildColocatedBufferSets(
618       const HloModule* module, const BufferLiveness& buffer_liveness,
619       const LogicalBuffer::SizeFunction& buffer_size,
620       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
621 
622   // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
623   // same set to the same buffer allocation in 'assignment'.
624   void AssignColocatedBufferSets(
625       const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
626       BufferAssignment* assignment,
627       absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers,
628       absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations);
629 
630   // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
631   // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
632   void AddSetToColocatedBufferSets(
633       const std::vector<const LogicalBuffer*>& colocated_set,
634       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
635 
636   // Given a list of colocated buffer sets (each colocated buffer set represents
637   // the logical buffers that would be assigned to the same physical buffer),
638   // try to merge the sets if the buffers can be shared. Returns the merged set.
639   std::vector<ColocatedBufferSet> MergeColocatedBufferSets(
640       const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
641       const BufferLiveness& buffer_liveness,
642       const LogicalBuffer::SizeFunction& buffer_size);
643 
644   // Split a set of buffers into several sets, each of which contains buffers
645   // colored with the same color.
646   absl::flat_hash_map<LogicalBuffer::Color,
647                       absl::flat_hash_set<const LogicalBuffer*>,
648                       LogicalBuffer::Color::Hasher>
649   SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers);
650 
651   // If true, allocate buffers for constant instructions.
652   bool allocate_buffers_for_constants_;
653 
654   // Functor used to assign colors to newly allocated logical buffers.
655   BufferLiveness::Colorer colorer_;
656 
657   // Functor to check if a buffer can reuse an allocation.
658   ReuseAllocationFunction reuse_checker_;
659 
660   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
661 };
662 
663 }  // namespace xla
664 
665 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
666