1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
18 
19 #include <ostream>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow/compiler/xla/service/hlo_value.h"
24 #include "tensorflow/compiler/xla/shape_tree.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/platform/macros.h"
28 
29 namespace xla {
30 
31 // A container which can hold one or more HloValues. An HLO buffer abstractly
32 // represents the allocation which HLO instructions write into and read
// from. Generally there is a one-to-one correspondence between HloBuffers and
// HloValues, where each HloValue in the module is held in a unique HloBuffer. An
// exception is the While instruction, which updates the loop state in place. In
// this case, we have a single HloBuffer for each HloPosition in the loop state,
37 // but multiple HloValues. For example:
38 //
39 //   %init = ...
40 //   %while = While(%init, body, condition)
41 //
42 //  body:
43 //   %body_param = Param(0)
44 //     ...
45 //   %body_root = ...
46 //
47 //  condition:
48 //   %cond_param = Param(0)
49 //     ...
50 //
51 // For simplicity, assume that %while is array-shaped. In this case, we have a
52 // single HloBuffer which holds the following HloValues: HloValue{%init},
53 // HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and
54 // HloValue{%cond_param}.
55 //
// HloBuffers may appear at multiple HloPositions in the module, mirroring the
// same property of HloValues. For example:
58 //
59 //   %sub = Sub(...)
60 //   %add = Add(...)
61 //   %tuple = Tuple(%add, %sub)
62 //   %gte = GetTupleElement(%tuple, 0)
63 //
64 // In this case, the HloBuffer containing %add appears at the following
65 // positions: HloPosition{%add, {}}, HloPosition{%tuple, {0}}, and
66 // HloPosition{%gte, {}}.
67 //
68 // Different HloPositions which share the same HloBuffer indicate mandatory
69 // aliasing in the HLO module. These positions must share the same memory
70 // allocation for correctness (the backends rely on this property). This differs
71 // from incidental aliasing introduced by memory reuse in BufferAssignment where
72 // different instructions may happen to get the same allocation.
73 class HloBuffer {
74  public:
75   using Id = int64;
76 
77   // Predicate comparing HloBuffers by increasing id, useful for std::sort.
IdLessThan(const HloBuffer * a,const HloBuffer * b)78   static bool IdLessThan(const HloBuffer* a, const HloBuffer* b) {
79     return a->id() < b->id();
80   }
81 
82   // Predicate comparing HloBuffers by equal id, useful for std::unique.
IdEqual(const HloBuffer * a,const HloBuffer * b)83   static bool IdEqual(const HloBuffer* a, const HloBuffer* b) {
84     return a->id() == b->id();
85   }
86 
HloBuffer(Id id,absl::Span<const HloValue * const> values)87   HloBuffer(Id id, absl::Span<const HloValue* const> values)
88       : id_(id), values_(values.begin(), values.end()) {}
89 
90   // Return the unique identifier for this HloBuffer.
id()91   Id id() const { return id_; }
92 
93   // Return all values contained in this buffer.
values()94   const std::vector<const HloValue*>& values() const { return values_; }
95 
96   // Memory space color. Used to indicate the memory space that the hlo buffer
97   // needs to live in.
color()98   BufferValue::Color color() const {
99     // Invariant: All values in the buffer should have the same color.
100     BufferValue::Color result = values()[0]->color();
101     for (const HloValue* value : values()) {
102       DCHECK_EQ(result, value->color());
103     }
104     return result;
105   }
106 
107   // Return the unique HLO value in the buffer. CHECK fails if the buffer does
108   // not contain exactly one value.
GetUniqueValue()109   const HloValue& GetUniqueValue() const {
110     CHECK_EQ(values_.size(), 1);
111     return *values_[0];
112   }
113 
114   std::vector<HloPosition> ComputePositions() const;
115 
116   string ToString() const;
117 
118   bool operator==(const HloBuffer& other) const;
119   bool operator!=(const HloBuffer& other) const { return !(*this == other); }
120 
121  private:
122   // Unique identifier for this HloBuffer.
123   Id id_;
124 
125   // The set of values contained in this buffer. Vector contains no duplicates
126   // and is sorted stably by HloValue::Id.
127   std::vector<const HloValue*> values_;
128 };
129 
130 std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
131 
132 }  // namespace xla
133 
134 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
135