/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shaped_buffer.h"

#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

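// Constructs a ShapedBuffer for `on_device_shape`; the matching host shape is
// derived from it via ShapeUtil::DeviceShapeToHostShape.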
ShapedBuffer::ShapedBuffer(Shape on_device_shape, int device_ordinal)
    : on_device_shape_(std::move(on_device_shape)),
      device_ordinal_(device_ordinal),
      buffers_(&on_device_shape_) {
  on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
}

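// Two-shape overload: the explicit `on_host_shape` argument is ignored; the
// stored host shape is recomputed from `on_device_shape` by the delegated-to
// constructor.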
ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                           int device_ordinal)
    : ShapedBuffer(on_device_shape, device_ordinal) {}

ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
    : on_host_shape_(std::move(s.on_host_shape_)),
      on_device_shape_(std::move(s.on_device_shape_)),
      device_ordinal_(s.device_ordinal_),
      buffers_(std::move(s.buffers_)) {
  // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
  // into buffers_, we also need to update this pointer so that buffers_ doesn't
  // point into s.
  buffers_.replace_shape_ptr(&on_device_shape_);
}

ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
  on_device_shape_ = std::move(s.on_device_shape_);
  on_host_shape_ = std::move(s.on_host_shape_);
  device_ordinal_ = s.device_ordinal_;
  buffers_ = std::move(s.buffers_);
  // buffers_ has a pointer to its on_device_shape_. When we move s.buffers_
  // into buffers_, we also need to update this pointer so that buffers_ doesn't
  // point into s.
  buffers_.replace_shape_ptr(&on_device_shape_);
  return *this;
}

ShapedBuffer::~ShapedBuffer() {}

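// Returns a ShapedBuffer that views the buffers rooted at `index`. The result
// aliases the same device memory as this buffer; no ownership is transferred.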
StatusOr<ShapedBuffer> ShapedBuffer::SubShapedBuffer(
    const ShapeIndex& index) const {
  TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape,
                      ShapeUtil::TryGetSubshape(on_device_shape(), index));
  ShapedBuffer sub_shaped_buffer(*device_sub_shape, device_ordinal_);
  TF_ASSIGN_OR_RETURN(ShapeTree<se::DeviceMemoryBase> sub_buffers,
                      buffers_.SubShapeTree(index));
  sub_shaped_buffer.set_buffers(std::move(sub_buffers));
  return std::move(sub_shaped_buffer);
}

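// Resets every buffer entry to a null DeviceMemoryBase. The device memory
// itself is not deallocated.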
void ShapedBuffer::clear() {
  for (auto& pair : buffers_) {
    // A default constructed DeviceMemoryBase is a null pointer.
    pair.second = se::DeviceMemoryBase();
  }
}

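// Renders the on-device shape and, for each subshape, the buffer pointer and
// size. Intended for debugging and logging.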
string ShapedBuffer::ToString() const {
  string s =
      absl::StrCat("ShapedBuffer(", device_ordinal(),
                   "), on-device shape=" +
                       ShapeUtil::HumanStringWithLayout(on_device_shape()),
                   ":\n");
  ShapeUtil::ForEachSubshape(
      on_device_shape(),
      [this, &s](const Shape& subshape, const ShapeIndex& index) {
        string shape_str;
        if (subshape.IsTuple()) {
          shape_str = "tuple";
        } else {
          shape_str = ShapeUtil::HumanStringWithLayout(subshape);
        }
        const se::DeviceMemoryBase& memory = buffer(index);
        absl::StrAppendFormat(&s, "  %s%p (%d bytes) : %s\n",
                              string(index.size() * 2, ' '), memory.opaque(),
                              memory.size(), shape_str);
      });
  return s;
}

std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
  out << buffer.ToString();
  return out;
}

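// Like the ShapedBuffer constructor, but also records the allocator that owns
// the underlying device memory so it can be freed on destruction.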
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
                                       se::DeviceMemoryAllocator* allocator,
                                       int device_ordinal)
    : ShapedBuffer(std::move(on_device_shape), device_ordinal),
      allocator_(allocator) {}

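// Two-shape overload: as with ShapedBuffer, `on_host_shape` is ignored and the
// host shape is derived from `on_device_shape`.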
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
                                       Shape on_device_shape,
                                       se::DeviceMemoryAllocator* allocator,
                                       int device_ordinal)
    : ScopedShapedBuffer(std::move(on_device_shape), allocator,
                         device_ordinal) {}

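// Takes over the buffers of an existing ShapedBuffer; this object now owns the
// device memory and will deallocate it through `allocator` on destruction.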
ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                                       se::DeviceMemoryAllocator* allocator)
    : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}

ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s)
    : ShapedBuffer(static_cast<ShapedBuffer&&>(s)), allocator_(s.allocator_) {
  // Null out s.allocator_ so it doesn't try to free anything in its destructor.
  s.allocator_ = nullptr;
}

ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) {
  Deallocate();

  *static_cast<ShapedBuffer*>(this) = std::move(static_cast<ShapedBuffer&>(s));
  allocator_ = s.allocator_;
  // Null out s.allocator_ so it doesn't try to free anything in its destructor.
  s.allocator_ = nullptr;
  return *this;
}

ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); }

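// Relinquishes ownership: returns the buffers as a plain ShapedBuffer and
// clears buffers_ so this object's destructor will not deallocate them.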
ShapedBuffer ScopedShapedBuffer::release() {
  ShapedBuffer shaped_buffer(static_cast<ShapedBuffer&&>(*this));
  buffers_ = ShapeTree<se::DeviceMemoryBase>();
  return shaped_buffer;
}

void ScopedShapedBuffer::Deallocate() {
  // allocator_ will be null if we were moved-from.
  if (allocator_ == nullptr) {
    return;
  }
  // Deallocate all non-null buffers. A buffer may appear in more than one spot
  // in the shape (eg, a tuple with a repeated element) so keep track of what
  // has been deallocated.
  absl::flat_hash_set<void*> deallocated_ptrs;
  for (auto& pair : buffers_) {
    se::DeviceMemoryBase& memory_base = pair.second;
    if (!memory_base.is_null() &&
        deallocated_ptrs.insert(memory_base.opaque()).second) {
      TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base));
    }
  }
}

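// Moves ownership of the buffers under `index` into a new ScopedShapedBuffer,
// nulling out the corresponding entries in this buffer.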
ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
  const xla::Shape& sub_on_device_shape =
      xla::ShapeUtil::GetSubshape(on_device_shape(), {index});

  ScopedShapedBuffer output(sub_on_device_shape, memory_allocator(),
                            device_ordinal());
  auto src_it = buffers().find(index);
  auto dst_it = output.buffers().begin();
  while (dst_it != output.buffers().end()) {
    dst_it->second = src_it->second;
    src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0);
    ++src_it;
    ++dst_it;
  }
  return output;
}

}  // namespace xla