1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ 18 19 #include <ostream> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/service/hlo_value.h" 24 #include "tensorflow/compiler/xla/shape_tree.h" 25 #include "tensorflow/compiler/xla/types.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/platform/macros.h" 28 29 namespace xla { 30 31 // A container which can hold one or more HloValues. An HLO buffer abstractly 32 // represents the allocation which HLO instructions write into and read 33 // from. Generally there is a one-to-one correspondence between HloBuffers and 34 // HloValue where each HloValue in the module is held in a unique HloBuffer. An 35 // exception is the while instruction which updates the loop state in-place. In 36 // this case, we have a single HloBuffer for each HloPosition in the loop state, 37 // but multiple HloValues. For example: 38 // 39 // %init = ... 40 // %while = While(%init, body, condition) 41 // 42 // body: 43 // %body_param = Param(0) 44 // ... 45 // %body_root = ... 46 // 47 // condition: 48 // %cond_param = Param(0) 49 // ... 50 // 51 // For simplicity, assume that %while is array-shaped. In this case, we have a 52 // single HloBuffer which holds the following HloValues: HloValue{%init}, 53 // HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and 54 // HloValue{%cond_param}. 55 // 56 // HloBuffers may appear at different HloPositions in the module mirroring the 57 // same propery of HloValues. For example: 58 // 59 // %sub = Sub(...) 60 // %add = Add(...) 61 // %tuple = Tuple(%add, %sub) 62 // %gte = GetTupleElement(%tuple, 0) 63 // 64 // In this case, the HloBuffer containing %add appears at the following 65 // positions: HloPosition{%add, {}}, HloPosition{%tuple, {0}}, and 66 // HloPosition{%gte, {}}. 67 // 68 // Different HloPositions which share the same HloBuffer indicate mandatory 69 // aliasing in the HLO module. These positions must share the same memory 70 // allocation for correctness (the backends rely on this property). This differs 71 // from incidental aliasing introduced by memory reuse in BufferAssignment where 72 // different instructions may happen to get the same allocation. 73 class HloBuffer { 74 public: 75 using Id = int64; 76 77 // Predicate comparing HloBuffers by increasing id, useful for std::sort. IdLessThan(const HloBuffer * a,const HloBuffer * b)78 static bool IdLessThan(const HloBuffer* a, const HloBuffer* b) { 79 return a->id() < b->id(); 80 } 81 82 // Predicate comparing HloBuffers by equal id, useful for std::unique. IdEqual(const HloBuffer * a,const HloBuffer * b)83 static bool IdEqual(const HloBuffer* a, const HloBuffer* b) { 84 return a->id() == b->id(); 85 } 86 HloBuffer(Id id,absl::Span<const HloValue * const> values)87 HloBuffer(Id id, absl::Span<const HloValue* const> values) 88 : id_(id), values_(values.begin(), values.end()) {} 89 90 // Return the unique identifier for this HloBuffer. id()91 Id id() const { return id_; } 92 93 // Return all values contained in this buffer. values()94 const std::vector<const HloValue*>& values() const { return values_; } 95 96 // Return the unique HLO value in the buffer. CHECK fails if the buffer does 97 // not contain exactly one value. GetUniqueValue()98 const HloValue& GetUniqueValue() const { 99 CHECK_EQ(values_.size(), 1); 100 return *values_[0]; 101 } 102 103 std::vector<HloPosition> ComputePositions() const; 104 105 string ToString() const; 106 107 bool operator==(const HloBuffer& other) const; 108 bool operator!=(const HloBuffer& other) const { return !(*this == other); } 109 110 private: 111 // Unique identifier for this HloBuffer. 112 const Id id_; 113 114 // The set of values contained in this buffer. Vector contains no duplicates 115 // and is sorted stably by HloValue::Id. 116 const std::vector<const HloValue*> values_; 117 }; 118 119 std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); 120 121 } // namespace xla 122 123 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ 124