1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/service/hlo_buffer.h" 26 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 27 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 28 #include "tensorflow/compiler/xla/service/hlo_module.h" 29 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 30 #include "tensorflow/compiler/xla/status.h" 31 #include "tensorflow/compiler/xla/statusor.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/platform/macros.h" 35 36 namespace xla { 37 38 // Analysis which allocates HloBuffers to HloValues. 39 class HloAliasAnalysis { 40 public: 41 // The callgraph of the given HloModule must be flattened 42 // (xla::FlattenCallGraph) prior to running the analysis. 43 static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run( 44 HloModule* module, 45 const HloDataflowAnalysis::FusionCanShareBufferFunction& 46 fusion_can_share_buffer); 47 48 string ToString() const; 49 50 // Return the buffer containing the given value. GetBufferContainingValue(const HloValue & value)51 const HloBuffer& GetBufferContainingValue(const HloValue& value) const { 52 return *value_to_buffer_.at(&value); 53 } GetBufferContainingValue(const HloValue & value)54 HloBuffer& GetBufferContainingValue(const HloValue& value) { 55 return *value_to_buffer_.at(&value); 56 } 57 58 // Return the HloBuffer with the given ID. GetBuffer(HloBuffer::Id buffer_id)59 const HloBuffer& GetBuffer(HloBuffer::Id buffer_id) const { 60 return buffers_.at(buffer_id); 61 } GetBuffer(HloBuffer::Id buffer_id)62 HloBuffer& GetBuffer(HloBuffer::Id buffer_id) { 63 return buffers_.at(buffer_id); 64 } 65 66 // Returns the unique buffer at the given position. CHECK fails if the buffer 67 // set at that position does not contain exactly one buffer. 68 const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, 69 const ShapeIndex& index = {}) const; 70 HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, 71 const ShapeIndex& index = {}); 72 73 // Compute the set of buffers at the given instruction and index and return as 74 // a vector. This set is exactly the union of the buffers containing the 75 // HloValues at this position. 76 std::vector<const HloBuffer*> ComputeBuffersAt( 77 const HloInstruction* instruction, const ShapeIndex& index = {}) const; 78 79 // Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This 80 // vector is lazily computed. Mutating operations on HloAliasAnalysis may 81 // invalidate the underlying vector requiring recomputation. buffers()82 const std::vector<HloBuffer>& buffers() const { return buffers_; } 83 84 // Returns the underlying dataflow analysis used by this alias analysis. dataflow_analysis()85 const HloDataflowAnalysis& dataflow_analysis() const { 86 return *dataflow_analysis_; 87 } 88 89 // Returns true if any index in the output of the given instruction has more 90 // than one buffer. That is, ComputeBuffersAt returns a vector with more than 91 // one element. 92 bool InstructionBuffersAreAmbiguous(const HloInstruction* instruction) const; 93 94 // Returns true if no HloBuffer appears in more than one shape index in the 95 // output of the given instruction. 96 bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const; 97 98 // Returns true if any HLO values in the module have interfering live ranges 99 // assuming the given ordering. 100 bool HasLiveRangeInterference(const HloOrdering& ordering) const; 101 102 protected: 103 explicit HloAliasAnalysis(HloModule* module); 104 105 // Verify various invariants of the alias analysis. 106 Status Verify() const; 107 108 HloModule* module_; 109 110 // The underlying dataflow analysis used by this alias analysis. 111 std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_; 112 113 // A map indicating which buffer a value is contained in. 114 absl::flat_hash_map<const HloValue*, HloBuffer*> value_to_buffer_; 115 116 // A lazily constructed vector containing all HloBuffers sorted by 117 // HloBuffer::Id. 118 std::vector<HloBuffer> buffers_; 119 }; 120 121 } // namespace xla 122 123 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ 124