1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/layout_util.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/platform/logging.h"
31
32 namespace xla {
33
ShapedBuffer(Shape on_device_shape,int device_ordinal)34 ShapedBuffer::ShapedBuffer(Shape on_device_shape, int device_ordinal)
35 : on_device_shape_(std::move(on_device_shape)),
36 device_ordinal_(device_ordinal),
37 buffers_(&on_device_shape_) {
38 on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
39 }
40
ShapedBuffer(Shape on_host_shape,Shape on_device_shape,int device_ordinal)41 ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
42 int device_ordinal)
43 : ShapedBuffer(on_device_shape, device_ordinal) {}
44
ShapedBuffer(ShapedBuffer && s)45 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
46 : on_host_shape_(std::move(s.on_host_shape_)),
47 on_device_shape_(std::move(s.on_device_shape_)),
48 device_ordinal_(s.device_ordinal_),
49 buffers_(std::move(s.buffers_)) {
50 // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
51 // into buffers_, we also need to update this pointer so that buffers_ doesn't
52 // point into s.
53 buffers_.replace_shape_ptr(&on_device_shape_);
54 }
55
operator =(ShapedBuffer && s)56 ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
57 on_device_shape_ = std::move(s.on_device_shape_);
58 on_host_shape_ = std::move(s.on_host_shape_);
59 device_ordinal_ = s.device_ordinal_;
60 buffers_ = std::move(s.buffers_);
61 // buffers_ has a pointer to its on_device_shape_. When we move s.buffers_
62 // into buffers_, we also need to update this pointer so that buffers_ doesn't
63 // point into s.
64 buffers_.replace_shape_ptr(&on_device_shape_);
65 return *this;
66 }
67
~ShapedBuffer()68 ShapedBuffer::~ShapedBuffer() {}
69
SubShapedBuffer(const ShapeIndex & index) const70 StatusOr<ShapedBuffer> ShapedBuffer::SubShapedBuffer(
71 const ShapeIndex& index) const {
72 TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape,
73 ShapeUtil::TryGetSubshape(on_device_shape(), index));
74 ShapedBuffer sub_shaped_buffer(*device_sub_shape, device_ordinal_);
75 TF_ASSIGN_OR_RETURN(ShapeTree<se::DeviceMemoryBase> sub_buffers,
76 buffers_.SubShapeTree(index));
77 sub_shaped_buffer.set_buffers(std::move(sub_buffers));
78 return std::move(sub_shaped_buffer);
79 }
80
clear()81 void ShapedBuffer::clear() {
82 for (auto& pair : buffers_) {
83 // A default constructed DeviceMemoryBase is a null pointer.
84 pair.second = se::DeviceMemoryBase();
85 }
86 }
87
ToString() const88 string ShapedBuffer::ToString() const {
89 string s =
90 absl::StrCat("ShapedBuffer(", device_ordinal(),
91 "), on-device shape=" +
92 ShapeUtil::HumanStringWithLayout(on_device_shape()),
93 ":\n");
94 ShapeUtil::ForEachSubshape(
95 on_device_shape(),
96 [this, &s](const Shape& subshape, const ShapeIndex& index) {
97 string shape_str;
98 if (subshape.IsTuple()) {
99 shape_str = "tuple";
100 } else {
101 shape_str = ShapeUtil::HumanStringWithLayout(subshape);
102 }
103 const se::DeviceMemoryBase& memory = buffer(index);
104 absl::StrAppendFormat(&s, " %s%p (%d bytes) : %s\n",
105 string(index.size() * 2, ' '), memory.opaque(),
106 memory.size(), shape_str);
107 });
108 return s;
109 }
110
operator <<(std::ostream & out,const ShapedBuffer & buffer)111 std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
112 out << buffer.ToString();
113 return out;
114 }
115
ScopedShapedBuffer(Shape on_device_shape,se::DeviceMemoryAllocator * allocator,int device_ordinal)116 ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
117 se::DeviceMemoryAllocator* allocator,
118 int device_ordinal)
119 : ShapedBuffer(std::move(on_device_shape), device_ordinal),
120 allocator_(allocator) {}
121
ScopedShapedBuffer(Shape on_host_shape,Shape on_device_shape,se::DeviceMemoryAllocator * allocator,int device_ordinal)122 ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
123 Shape on_device_shape,
124 se::DeviceMemoryAllocator* allocator,
125 int device_ordinal)
126 : ScopedShapedBuffer(std::move(on_device_shape), allocator,
127 device_ordinal) {}
128
ScopedShapedBuffer(ShapedBuffer shaped_buffer,se::DeviceMemoryAllocator * allocator)129 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
130 se::DeviceMemoryAllocator* allocator)
131 : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}
132
ScopedShapedBuffer(ScopedShapedBuffer && s)133 ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s)
134 : ShapedBuffer(static_cast<ShapedBuffer&&>(s)), allocator_(s.allocator_) {
135 // Null out s.allocator_ so it doesn't try to free anything in its destructor.
136 s.allocator_ = nullptr;
137 }
138
operator =(ScopedShapedBuffer && s)139 ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) {
140 Deallocate();
141
142 *static_cast<ShapedBuffer*>(this) = std::move(static_cast<ShapedBuffer&>(s));
143 allocator_ = s.allocator_;
144 // Null out s.allocator_ so it doesn't try to free anything in its destructor.
145 s.allocator_ = nullptr;
146 return *this;
147 }
148
~ScopedShapedBuffer()149 ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); }
150
release()151 ShapedBuffer ScopedShapedBuffer::release() {
152 ShapedBuffer shaped_buffer(static_cast<ShapedBuffer&&>(*this));
153 buffers_ = ShapeTree<se::DeviceMemoryBase>();
154 return shaped_buffer;
155 }
156
Deallocate()157 void ScopedShapedBuffer::Deallocate() {
158 // allocator_ will be null if we were moved-from.
159 if (allocator_ == nullptr) {
160 return;
161 }
162 // Deallocate all non-null buffers. A buffer may appear in more than one spot
163 // in the shape (eg, a tuple with a repeated element) so keep track of what
164 // has been deallocated.
165 absl::flat_hash_set<void*> deallocated_ptrs;
166 for (auto& pair : buffers_) {
167 se::DeviceMemoryBase& memory_base = pair.second;
168 if (!memory_base.is_null() &&
169 deallocated_ptrs.insert(memory_base.opaque()).second) {
170 TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base));
171 }
172 }
173 }
174
TakeSubTree(ShapeIndexView index)175 ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
176 const xla::Shape& sub_on_device_shape =
177 xla::ShapeUtil::GetSubshape(on_device_shape(), {index});
178
179 ScopedShapedBuffer output(sub_on_device_shape, memory_allocator(),
180 device_ordinal());
181 auto src_it = buffers().find(index);
182 auto dst_it = output.buffers().begin();
183 while (dst_it != output.buffers().end()) {
184 dst_it->second = src_it->second;
185 src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0);
186 ++src_it;
187 ++dst_it;
188 }
189 return output;
190 }
191
192 } // namespace xla
193