1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
18 
#include <memory>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
35 
36 namespace xla {
37 
38 // Analysis which allocates HloBuffers to HloValues.
39 class HloAliasAnalysis {
40  public:
41   // The callgraph of the given HloModule must be flattened
42   // (xla::FlattenCallGraph) prior to running the analysis.
43   static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run(
44       const HloModule* module,
45       const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr);
46 
47   string ToString() const;
48 
49   // Return the buffer containing the given value.
GetBufferContainingValue(const HloValue & value)50   const HloBuffer& GetBufferContainingValue(const HloValue& value) const {
51     return *value_to_buffer_.at(&value);
52   }
GetBufferContainingValue(const HloValue & value)53   HloBuffer& GetBufferContainingValue(const HloValue& value) {
54     return *value_to_buffer_.at(&value);
55   }
56 
57   // Return the HloBuffer with the given ID.
GetBuffer(HloBuffer::Id buffer_id)58   const HloBuffer& GetBuffer(HloBuffer::Id buffer_id) const {
59     return buffers_.at(buffer_id);
60   }
GetBuffer(HloBuffer::Id buffer_id)61   HloBuffer& GetBuffer(HloBuffer::Id buffer_id) {
62     return buffers_.at(buffer_id);
63   }
64 
65   // Returns the unique buffer at the given position. CHECK fails if the buffer
66   // set at that position does not contain exactly one buffer.
67   const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
68                                      const ShapeIndex& index = {}) const;
69   HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
70                                const ShapeIndex& index = {});
71 
72   // Compute the set of buffers at the given instruction and index and return as
73   // a vector. This set is exactly the union of the buffers containing the
74   // HloValues at this position.
75   std::vector<const HloBuffer*> ComputeBuffersAt(
76       const HloInstruction* instruction, const ShapeIndex& index = {}) const;
77 
78   // Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This
79   // vector is lazily computed. Mutating operations on HloAliasAnalysis may
80   // invalidate the underlying vector requiring recomputation.
buffers()81   const std::vector<HloBuffer>& buffers() const { return buffers_; }
82 
83   // Returns the underlying dataflow analysis used by this alias analysis.
dataflow_analysis()84   HloDataflowAnalysis& dataflow_analysis() const { return *dataflow_analysis_; }
85 
86   // Returns true if any index in the output of the given instruction has more
87   // than one buffer. That is, ComputeBuffersAt returns a vector with more than
88   // one element.
89   bool InstructionBuffersAreAmbiguous(const HloInstruction* instruction) const;
90 
91   // Returns true if no HloBuffer appears in more than one shape index in the
92   // output of the given instruction.
93   bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const;
94 
95   // Merge buffer `from` into buffer `to`. Caller has to make sure no
96   // interference will be introduced after merging. This rebuilds internal data
97   // structure, and invalidates references to all existing buffers.
98   void MergeBuffers(const HloBuffer& to, const HloBuffer& from);
99 
100   // Returns true if any HLO values in the module have interfering live ranges
101   // assuming the given ordering.
102   bool HasLiveRangeInterference(const HloOrdering& ordering) const;
103 
104   // Returns true if a buffer lives out of the module.
BufferLivesOut(const HloBuffer & buffer)105   bool BufferLivesOut(const HloBuffer& buffer) const {
106     return live_out_buffers_.count(&buffer);
107   }
108 
109   // Returns true if a hlo value lives out of the module.
ValueLivesOut(const HloValue & value)110   bool ValueLivesOut(const HloValue& value) const {
111     return live_out_buffers_.count(&GetBufferContainingValue(value));
112   }
113 
LiveOutBuffers()114   std::vector<const HloBuffer*> LiveOutBuffers() const {
115     std::vector<const HloBuffer*> results(live_out_buffers_.begin(),
116                                           live_out_buffers_.end());
117     absl::c_sort(results, [](const HloBuffer* a, const HloBuffer* b) {
118       return a->id() < b->id();
119     });
120     return results;
121   }
122 
123  protected:
124   explicit HloAliasAnalysis(const HloModule* module);
125 
126   // Verify various invariants of the alias analysis.
127   Status Verify() const;
128 
129   const HloModule* module_;
130 
131   // A set of buffers that live out the module.
132   absl::flat_hash_set<const HloBuffer*> live_out_buffers_;
133 
134   // The underlying dataflow analysis used by this alias analysis.
135   std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
136 
137   // A map indicating which buffer a value is contained in.
138   absl::flat_hash_map<const HloValue*, HloBuffer*> value_to_buffer_;
139 
140   // A lazily constructed vector containing all HloBuffers sorted by
141   // HloBuffer::Id.
142   std::vector<HloBuffer> buffers_;
143 };
144 
145 }  // namespace xla
146 
147 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
148