1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/allocation_tracker.h"
17 
18 #include <utility>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/compiler/xla/map_util.h"
23 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
24 #include "tensorflow/compiler/xla/service/transfer_manager.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/core/platform/logging.h"
30 
31 namespace xla {
32 
Register(ScopedShapedBuffer shaped_buffer,const string & tag)33 StatusOr<GlobalDataHandle> AllocationTracker::Register(
34     ScopedShapedBuffer shaped_buffer, const string& tag) {
35   tensorflow::mutex_lock lock(mutex_);
36   VLOG(2) << "Register";
37   std::vector<ScopedShapedBuffer> replicated_buffers;
38   replicated_buffers.emplace_back(std::move(shaped_buffer));
39   return RegisterInternal(std::move(replicated_buffers), tag);
40 }
41 
RegisterReplicatedBuffers(std::vector<ScopedShapedBuffer> replicated_buffers,const string & tag)42 StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
43     std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
44   tensorflow::mutex_lock lock(mutex_);
45   VLOG(2) << "RegisterReplicatedBuffers";
46   return RegisterInternal(std::move(replicated_buffers), tag);
47 }
48 
49 // ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
50 // b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
51 // unmodified.
ReleaseIfScopedShapedBuffer(ShapedBuffer b)52 static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b)53 static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
54   return b.release();
55 }
56 
57 template <typename ShapedBufferTy>
RegisterInternal(std::vector<ShapedBufferTy> replicated_buffers,const string & tag)58 StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
59     std::vector<ShapedBufferTy> replicated_buffers, const string& tag) {
60   static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
61                     std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
62                 "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
63   VLOG(2) << "RegisterInternal("
64           << "tag: \"" << tag << "\" with " << replicated_buffers.size()
65           << " shaped_buffers.";
66   for (const auto& shaped_buffer : replicated_buffers) {
67     VLOG(2) << "shaped_buffer:" << shaped_buffer;
68     if (shaped_buffer.platform() != backend_->platform()) {
69       return InvalidArgument(
70           "AllocationTracker for platform %s cannot register buffer from "
71           "platform %s",
72           backend_->platform()->Name(), shaped_buffer.platform()->Name());
73     }
74   }
75 
76   int64 handle = next_handle_++;
77   for (auto& shaped_buffer : replicated_buffers) {
78     std::vector<ShapeIndex> shape_indices;
79     ShapeUtil::ForEachSubshape(
80         shaped_buffer.on_device_shape(),
81         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
82           shape_indices.push_back(index);
83         });
84     // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
85     // them.
86     for (const ShapeIndex& index : shape_indices) {
87       AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
88                                        shaped_buffer.device_ordinal());
89     }
90     // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
91     // into a regular ShapedBuffer, which is stored in
92     // handle_to_shaped_buffers_.
93     handle_to_shaped_buffers_[handle].emplace_back(
94         absl::make_unique<ShapedBuffer>(
95             ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
96   }
97 
98   GlobalDataHandle result;
99   result.set_handle(handle);
100   VLOG(2) << "handle: " << handle;
101   return result;
102 }
103 
Unregister(const GlobalDataHandle & data)104 Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
105   tensorflow::mutex_lock lock(mutex_);
106   VLOG(2) << "Unregister("
107           << "handle: " << data.handle() << ")";
108   TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
109                       ResolveInternal(data));
110   for (const auto& shaped_buffer : replicated_buffers) {
111     std::vector<ShapeIndex> shape_indices;
112     ShapeUtil::ForEachSubshape(
113         shaped_buffer->on_device_shape(),
114         [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
115           shape_indices.push_back(index);
116         });
117     for (const ShapeIndex& index : shape_indices) {
118       TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
119                                            shaped_buffer->device_ordinal()));
120     }
121   }
122   // Keep a nullptr as a tombstone for unregistered handles. This enables
123   // better error messages. That is, "handle has been deallocated" versus
124   // "handle does not exist".
125   auto it = handle_to_shaped_buffers_.find(data.handle());
126   if (it == handle_to_shaped_buffers_.end()) {
127     return NotFound("no allocation record for global data handle: %d",
128                     data.handle());
129   }
130   for (auto& shaped_buffer : it->second) {
131     shaped_buffer.reset();
132   }
133   return Status::OK();
134 }
135 
DeconstructTuple(const GlobalDataHandle & data)136 StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
137     const GlobalDataHandle& data) {
138   tensorflow::mutex_lock lock(mutex_);
139 
140   TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
141                       ResolveInternal(data));
142   // We only need to care about replica id 0 here, since the GlobalDataHandle is
143   // the same for all buffers across replicas.
144   const ShapedBuffer* shaped_buffer = replicated_buffers[0];
145   if (!shaped_buffer->on_host_shape().IsTuple()) {
146     return InvalidArgument("global data handle %d is not a tuple",
147                            data.handle());
148   }
149   // If the on-host representation is a tuple, then the on-device one should be
150   // as well.
151   TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple());
152 
153   if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
154     return Unimplemented("Deconstructing nested tuples is not implemented.");
155   }
156 
157   std::vector<GlobalDataHandle> element_handles;
158   for (int i = 0;
159        i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
160        ++i) {
161     auto element_buffer = ShapedBuffer(
162         ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
163         ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
164         shaped_buffer->platform(), shaped_buffer->device_ordinal());
165     element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
166                               /*index=*/{});
167     std::vector<ShapedBuffer> replicated_buffers;
168     replicated_buffers.push_back(std::move(element_buffer));
169     TF_ASSIGN_OR_RETURN(
170         GlobalDataHandle element_handle,
171         RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));
172 
173     element_handles.push_back(element_handle);
174   }
175   return std::move(element_handles);
176 }
177 
Resolve(const GlobalDataHandle & data) const178 StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
179     const GlobalDataHandle& data) const {
180   tensorflow::mutex_lock lock(mutex_);
181   return AllocationTracker::ResolveInternal(data);
182 }
183 
ResolveForReplica(const GlobalDataHandle & data,int replica_id) const184 StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
185     const GlobalDataHandle& data, int replica_id) const {
186   tensorflow::mutex_lock lock(mutex_);
187   TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
188                       ResolveInternal(data));
189   if (replica_id >= replicated_buffers.size()) {
190     return InvalidArgument(
191         "Requesting buffer for replica %d, but found buffers only for %lu "
192         "replicas.",
193         replica_id, replicated_buffers.size());
194   }
195   return replicated_buffers[replica_id];
196 }
197 
ResolveInternal(const GlobalDataHandle & data) const198 StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
199     const GlobalDataHandle& data) const {
200   VLOG(2) << "resolve:" << data.handle();
201   auto it = handle_to_shaped_buffers_.find(data.handle());
202   if (it == handle_to_shaped_buffers_.end()) {
203     return NotFound("no allocation record for global data handle: %d",
204                     data.handle());
205   }
206   std::vector<const ShapedBuffer*> replicated_buffers;
207   for (const auto& shaped_buffer : it->second) {
208     if (shaped_buffer == nullptr) {
209       return InvalidArgument("global data handle %d was previously deallocated",
210                              data.handle());
211     }
212     replicated_buffers.push_back(shaped_buffer.get());
213   }
214 
215   return replicated_buffers;
216 }
217 
AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory,int device_ordinal)218 void AllocationTracker::AddAllocationOrIncrementRefCount(
219     se::DeviceMemoryBase device_memory, int device_ordinal) {
220   AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
221   auto it = allocation_map.find(device_memory.opaque());
222   if (it == allocation_map.end()) {
223     allocation_map[device_memory.opaque()] = {
224         OwningDeviceMemory(device_memory, device_ordinal,
225                            backend_->memory_allocator()),
226         /*ref_count=*/1};
227   } else {
228     it->second.ref_count++;
229   }
230 }
231 
DecrementRefCount(se::DeviceMemoryBase device_memory,int device_ordinal)232 Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
233                                             int device_ordinal) {
234   AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
235   auto it = allocation_map.find(device_memory.opaque());
236   TF_RET_CHECK(it != allocation_map.end());
237   Allocation& allocation = it->second;
238   TF_RET_CHECK(allocation.ref_count >= 1);
239   if (allocation.ref_count == 1) {
240     allocation.device_memory.Free();
241     allocation_map.erase(it);
242   } else {
243     allocation.ref_count--;
244   }
245   return Status::OK();
246 }
247 
248 }  // namespace xla
249