1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xrt/xrt_memory_manager.h"
17 
18 #include <algorithm>
19 #include <list>
20 #include <unordered_map>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xrt/xrt_metrics.h"
24 #include "tensorflow/core/lib/monitoring/timed.h"
25 #include "tensorflow/core/lib/random/random.h"
26 #include "tensorflow/core/profiler/lib/traceme.h"
27 
28 namespace tensorflow {
29 namespace {
30 
// Number of high-order bits of an int64 handle that hold the device ordinal.
// The device lives in the upper part of the handle so that the random bits sit
// in the lower part, which is better when the handle is used as a key in
// unordered maps.
constexpr int kDeviceBits = 12;
36 
MakeDeviceHandle(int64 device_ordinal,int64 rnd_value)37 int64 MakeDeviceHandle(int64 device_ordinal, int64 rnd_value) {
38   const int64 kUidMask = (static_cast<int64>(1) << (64 - kDeviceBits)) - 1;
39   return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask);
40 }
41 
GetDeviceFromHandle(int64 handle)42 int GetDeviceFromHandle(int64 handle) {
43   return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1);
44 }
45 
46 }  // namespace
47 
48 class XRTMemoryManager::DeviceContext {
49   struct Alloc {
Alloctensorflow::XRTMemoryManager::DeviceContext::Alloc50     explicit Alloc(RefPtr<XRTTupleAllocation> tuple)
51         : tuple(std::move(tuple)) {}
52 
53     RefPtr<XRTTupleAllocation> tuple;
54   };
55 
56   using AllocList = std::list<Alloc>;
57 
58  public:
Register(RefPtr<XRTTupleAllocation> tuple)59   int64 Register(RefPtr<XRTTupleAllocation> tuple) {
60     while (true) {
61       int64 handle = MakeDeviceHandle(tuple->device_ordinal(), CreateUid());
62       mutex_lock lock(lock_);
63       allocs_.emplace_front(tuple);
64       if (alloc_map_.emplace(handle, allocs_.begin()).second) {
65         return handle;
66       }
67       // The chances of hitting an existing handle are so remote, it is much
68       // more convenient to add to the list before, and eventually removing.
69       allocs_.erase(allocs_.begin());
70     }
71   }
72 
Release(int64 handle)73   bool Release(int64 handle) {
74     mutex_lock lock(lock_);
75     auto it = alloc_map_.find(handle);
76     if (it == alloc_map_.end()) {
77       return false;
78     }
79     allocs_.erase(it->second);
80     alloc_map_.erase(it);
81     return true;
82   }
83 
Lookup(int64 handle)84   RefPtr<XRTTupleAllocation> Lookup(int64 handle) {
85     mutex_lock lock(lock_);
86     auto it = alloc_map_.find(handle);
87     if (it == alloc_map_.end()) {
88       return nullptr;
89     }
90     // LRU
91     allocs_.splice(allocs_.begin(), allocs_, it->second);
92     return it->second->tuple;
93   }
94 
Clear()95   void Clear() {
96     mutex_lock lock(lock_);
97     alloc_map_.clear();
98     allocs_.clear();
99   }
100 
CompactAllocations(XRTMemoryManager * memory_manager,xla::Backend * backend)101   Status CompactAllocations(XRTMemoryManager* memory_manager,
102                             xla::Backend* backend) {
103     profiler::TraceMe trace_me("XRTMemoryManager::CompactAllocations",
104                                /*level=*/2);
105     auto timed = monitoring::MakeTimed(xrt_metrics::GetMemoryCompactCell());
106     VLOG(4) << "CompactAllocations started";
107     mutex_lock lock(lock_);
108     Status status;
109     std::vector<AllocList::iterator> swapped;
110     // We are swapping out from the most recently used allocations. This is
111     // desirable since the most recently used will be finding themselves at the
112     // bottom of the allocation space. Since these are more likely to be pinned
113     // allocations, a further trim done by following TryFreeMemory() call will
114     // eventually drop the higher located allocations, with better chance of
115     // reducing fragmentation.
116     // Also, by swapping out the pinned allocations first, those will also be
117     // the first to be restored, and hence if we will ever find OOM on the way
118     // out, we would more likely be swapping in not pinned ones.
119     for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
120       // We are compacting all the allocations, so we will temporarily swap out
121       // even pinned allocations.
122       auto swap_result_or = it->tuple->SwapOut(backend, /*swap_pinned=*/true);
123       if (!swap_result_or.ok()) {
124         status = swap_result_or.status();
125         break;
126       }
127       if (swap_result_or.ValueOrDie()) {
128         swapped.push_back(it);
129       }
130     }
131     // At this point we have released all the device memory we could release.
132     // Load back the tuple allocations we have swapped out above.
133     for (auto& it : swapped) {
134       auto swap_result_or = it->tuple->SwapIn(memory_manager, backend);
135       if (!swap_result_or.ok()) {
136         // If we failed to restored a pinned allocation, better to CHECK here
137         // than wondering why XRTTupleAllocation calls fail with errors about
138         // missing buffers.
139         CHECK(!it->tuple->IsPinned());  // Crash OK
140         if (status.ok()) {
141           status = swap_result_or.status();
142         }
143       }
144     }
145     VLOG(4) << "CompactAllocations finished: " << status;
146     return status;
147   }
148 
149   // Tries to free size bytes by freeing some unpinned device memory. Returns
150   // the amount of memory which was able to free.
TryFreeMemory(xla::Backend * backend,size_t size)151   xla::StatusOr<size_t> TryFreeMemory(xla::Backend* backend, size_t size) {
152     profiler::TraceMe trace_me("XRTMemoryManager::TryFreeMemory", /*level=*/2);
153     auto timed = monitoring::MakeTimed(xrt_metrics::GetTryFreeMemoryCell());
154     mutex_lock lock(lock_);
155     size_t swapped_size = 0;
156     for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
157       TF_ASSIGN_OR_RETURN(bool swap_result,
158                           it->tuple->SwapOut(backend, /*swap_pinned=*/false));
159       if (swap_result) {
160         swapped_size += it->tuple->GetDeviceMemorySize();
161         if (swapped_size >= size) {
162           break;
163         }
164       }
165     }
166     VLOG(3) << "Swapped out " << swapped_size << " bytes";
167     return swapped_size;
168   }
169 
170  private:
CreateUid()171   static int64 CreateUid() {
172     int64 uid;
173     do {
174       uid = random::New64() & INT64_MAX;
175     } while (uid == InvalidKey());
176     return uid;
177   }
178 
179   // We store Alloc records inside an std::list<Alloc> so we can LRU it, and
180   // store the list iterators within the handle map, as list iterators don't get
181   // invalidated by (other elements) removals or position swaps.
182   mutex lock_;
183   AllocList allocs_;
184   std::unordered_map<int64, AllocList::iterator> alloc_map_;
185 };
186 
WorkingSet(RefPtr<XRTMemoryManager> memory_manager)187 XRTMemoryManager::WorkingSet::WorkingSet(
188     RefPtr<XRTMemoryManager> memory_manager)
189     : memory_manager_(std::move(memory_manager)) {}
190 
~WorkingSet()191 XRTMemoryManager::WorkingSet::~WorkingSet() {
192   for (auto& tuple : pinned_tuples_) {
193     tuple->Unpin();
194   }
195 }
196 
LookupAndPin(xla::Backend * backend,int64 handle)197 Status XRTMemoryManager::WorkingSet::LookupAndPin(xla::Backend* backend,
198                                                   int64 handle) {
199   TF_ASSIGN_OR_RETURN(auto tuple, memory_manager_->Lookup(handle));
200   TF_RETURN_IF_ERROR(
201       tuple->PinAndSwapIn(memory_manager_.get(), backend).status());
202   pinned_tuples_.push_back(std::move(tuple));
203   return Status::OK();
204 }
205 
Get(ResourceMgr * rm)206 /* static */ RefPtr<XRTMemoryManager> XRTMemoryManager::Get(ResourceMgr* rm) {
207   static string* container = new string("XrtState");
208   static string* name = new string("MemoryManager");
209   XRTMemoryManager* memory_manager = nullptr;
210   TF_CHECK_OK(rm->LookupOrCreate<XRTMemoryManager>(
211       *container, *name, &memory_manager, [](XRTMemoryManager** ret) {
212         *ret = new XRTMemoryManager();
213         return Status::OK();
214       }));
215   return memory_manager;
216 }
217 
Register(RefPtr<XRTTupleAllocation> tuple)218 int64 XRTMemoryManager::Register(RefPtr<XRTTupleAllocation> tuple) {
219   DeviceContext* device_context = GetDeviceContext(tuple->device_ordinal(),
220                                                    /*create_if_missing=*/true);
221   return device_context->Register(std::move(tuple));
222 }
223 
Lookup(int64 handle)224 xla::StatusOr<RefPtr<XRTTupleAllocation>> XRTMemoryManager::Lookup(
225     int64 handle) {
226   int device_ordinal = GetDeviceFromHandle(handle);
227   DeviceContext* device_context = GetDeviceContext(device_ordinal,
228                                                    /*create_if_missing=*/false);
229   if (device_context == nullptr) {
230     return errors::NotFound("XRT memory handle not found: ", handle);
231   }
232   RefPtr<XRTTupleAllocation> tuple = device_context->Lookup(handle);
233   if (tuple == nullptr) {
234     return errors::NotFound("XRT memory handle not found: ", handle);
235   }
236   return std::move(tuple);
237 }
238 
Release(int64 handle)239 Status XRTMemoryManager::Release(int64 handle) {
240   int device_ordinal = GetDeviceFromHandle(handle);
241   DeviceContext* device_context = GetDeviceContext(device_ordinal,
242                                                    /*create_if_missing=*/false);
243   if (device_context == nullptr || !device_context->Release(handle)) {
244     return errors::NotFound("XRT memory handle not found: ", handle);
245   }
246   return Status::OK();
247 }
248 
CompactAllocations(xla::Backend * backend,int device_ordinal)249 Status XRTMemoryManager::CompactAllocations(xla::Backend* backend,
250                                             int device_ordinal) {
251   DeviceContext* device_context = GetDeviceContext(device_ordinal,
252                                                    /*create_if_missing=*/false);
253   return device_context != nullptr
254              ? device_context->CompactAllocations(this, backend)
255              : Status::OK();
256 }
257 
ReleaseAllAllocations()258 void XRTMemoryManager::ReleaseAllAllocations() {
259   mutex_lock lock(lock_);
260   for (auto& device_context : device_contexts_) {
261     if (device_context != nullptr) {
262       device_context->Clear();
263     }
264   }
265 }
266 
Allocate(xla::Backend * backend,int device_ordinal,size_t size)267 xla::StatusOr<se::OwningDeviceMemory> XRTMemoryManager::Allocate(
268     xla::Backend* backend, int device_ordinal, size_t size) {
269   se::DeviceMemoryAllocator* allocator = backend->memory_allocator();
270   auto memory_or =
271       allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false);
272   if (memory_or.status().code() == error::RESOURCE_EXHAUSTED) {
273     VLOG(4) << "Allocate of " << size << " bytes failed on device "
274             << device_ordinal;
275 
276     DeviceContext* device_context =
277         GetDeviceContext(device_ordinal,
278                          /*create_if_missing=*/false);
279     if (device_context != nullptr) {
280       Status status = device_context->TryFreeMemory(backend, size).status();
281       if (status.ok()) {
282         // As long as there is no error, we still try again the allocation, even
283         // if the TryFreeMemory() call ended up freeing less memory than the
284         // required size. Fragmentation could make the memory allocation succeed
285         // even if the freed memory is indeed lower.
286         memory_or = allocator->Allocate(device_ordinal, size,
287                                         /*retry_on_failure=*/false);
288       } else if (status.code() != error::RESOURCE_EXHAUSTED) {
289         VLOG(4) << "Allocate of " << size << " bytes on device "
290                 << device_ordinal << ": " << status;
291         return status;
292       }
293     }
294   }
295   return memory_or;
296 }
297 
DebugString() const298 string XRTMemoryManager::DebugString() const {
299   // We might want to emit more detailed information here, like per device
300   // memory allocations.
301   return "XRTMemoryManager";
302 }
303 
GetDeviceContext(int device_ordinal,bool create_if_missing)304 XRTMemoryManager::DeviceContext* XRTMemoryManager::GetDeviceContext(
305     int device_ordinal, bool create_if_missing) {
306   mutex_lock lock(lock_);
307   if (device_ordinal >= device_contexts_.size()) {
308     if (!create_if_missing) {
309       return nullptr;
310     }
311     device_contexts_.resize(device_ordinal + 1);
312   }
313   DeviceContext* device_context = device_contexts_[device_ordinal].get();
314   if (device_context == nullptr && create_if_missing) {
315     device_contexts_[device_ordinal] = absl::make_unique<DeviceContext>();
316     device_context = device_contexts_[device_ordinal].get();
317   }
318   return device_context;
319 }
320 
TryFreeMemoryStep(MemoryReclaimContext * mrctx,const Status & status)321 Status XRTMemoryManager::TryFreeMemoryStep(MemoryReclaimContext* mrctx,
322                                            const Status& status) {
323   DeviceContext* device_context = GetDeviceContext(mrctx->device_ordinal,
324                                                    /*create_if_missing=*/false);
325   if (device_context == nullptr) {
326     return status;
327   }
328   if (!mrctx->done_freeing) {
329     // If the caller passed us a zero requested_free_size, we try to free chunks
330     // of kMaxFreeSize memory, until either the run function succeeds, or we run
331     // out of freeable memory.
332     const size_t kMaxFreeSize = 1000000000;
333     size_t free_size =
334         (mrctx->requested_free_size > 0)
335             ? std::min<size_t>(mrctx->requested_free_size - mrctx->free_size,
336                                kMaxFreeSize)
337             : kMaxFreeSize;
338     if (free_size > 0) {
339       auto free_size_or =
340           device_context->TryFreeMemory(mrctx->backend, free_size);
341       if (!free_size_or.ok()) {
342         return status;
343       }
344       size_t size = free_size_or.ValueOrDie();
345       mrctx->free_size += size;
346       if (size > 0) {
347         return Status::OK();
348       }
349     }
350     mrctx->done_freeing = true;
351   }
352   if (!mrctx->done_compacting) {
353     mrctx->done_compacting = true;
354     if (device_context->CompactAllocations(this, mrctx->backend).ok()) {
355       return Status::OK();
356     }
357   }
358   return status;
359 }
360 
361 }  // namespace tensorflow
362