1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
18 
19 #include <ostream>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow/compiler/xla/service/hlo_value.h"
24 #include "tensorflow/compiler/xla/shape_tree.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/platform/macros.h"
28 
29 namespace xla {
30 
31 // A container which can hold one or more HloValues. An HLO buffer abstractly
32 // represents the allocation which HLO instructions write into and read
33 // from. Generally there is a one-to-one correspondence between HloBuffers and
34 // HloValue where each HloValue in the module is held in a unique HloBuffer. An
35 // exception is the while instruction which updates the loop state in-place. In
36 // this case, we have a single HloBuffer for each HloPosition in the loop state,
37 // but multiple HloValues. For example:
38 //
39 //   %init = ...
40 //   %while = While(%init, body, condition)
41 //
42 //  body:
43 //   %body_param = Param(0)
44 //     ...
45 //   %body_root = ...
46 //
47 //  condition:
48 //   %cond_param = Param(0)
49 //     ...
50 //
51 // For simplicity, assume that %while is array-shaped. In this case, we have a
52 // single HloBuffer which holds the following HloValues: HloValue{%init},
53 // HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and
54 // HloValue{%cond_param}.
55 //
56 // HloBuffers may appear at different HloPositions in the module mirroring the
57 // same propery of HloValues. For example:
58 //
59 //   %sub = Sub(...)
60 //   %add = Add(...)
61 //   %tuple = Tuple(%add, %sub)
62 //   %gte = GetTupleElement(%tuple, 0)
63 //
64 // In this case, the HloBuffer containing %add appears at the following
65 // positions: HloPosition{%add, {}}, HloPosition{%tuple, {0}}, and
66 // HloPosition{%gte, {}}.
67 //
68 // Different HloPositions which share the same HloBuffer indicate mandatory
69 // aliasing in the HLO module. These positions must share the same memory
70 // allocation for correctness (the backends rely on this property). This differs
71 // from incidental aliasing introduced by memory reuse in BufferAssignment where
72 // different instructions may happen to get the same allocation.
73 class HloBuffer {
74  public:
75   using Id = int64;
76 
77   // Predicate comparing HloBuffers by increasing id, useful for std::sort.
IdLessThan(const HloBuffer * a,const HloBuffer * b)78   static bool IdLessThan(const HloBuffer* a, const HloBuffer* b) {
79     return a->id() < b->id();
80   }
81 
82   // Predicate comparing HloBuffers by equal id, useful for std::unique.
IdEqual(const HloBuffer * a,const HloBuffer * b)83   static bool IdEqual(const HloBuffer* a, const HloBuffer* b) {
84     return a->id() == b->id();
85   }
86 
HloBuffer(Id id,absl::Span<const HloValue * const> values)87   HloBuffer(Id id, absl::Span<const HloValue* const> values)
88       : id_(id), values_(values.begin(), values.end()) {}
89 
90   // Return the unique identifier for this HloBuffer.
id()91   Id id() const { return id_; }
92 
93   // Return all values contained in this buffer.
values()94   const std::vector<const HloValue*>& values() const { return values_; }
95 
96   // Return the unique HLO value in the buffer. CHECK fails if the buffer does
97   // not contain exactly one value.
GetUniqueValue()98   const HloValue& GetUniqueValue() const {
99     CHECK_EQ(values_.size(), 1);
100     return *values_[0];
101   }
102 
103   std::vector<HloPosition> ComputePositions() const;
104 
105   string ToString() const;
106 
107   bool operator==(const HloBuffer& other) const;
108   bool operator!=(const HloBuffer& other) const { return !(*this == other); }
109 
110  private:
111   // Unique identifier for this HloBuffer.
112   const Id id_;
113 
114   // The set of values contained in this buffer. Vector contains no duplicates
115   // and is sorted stably by HloValue::Id.
116   const std::vector<const HloValue*> values_;
117 };
118 
119 std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
120 
121 }  // namespace xla
122 
123 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
124