1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/lib/io/cache.h"
17
18 #include <assert.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22
23 #include "tensorflow/core/platform/coding.h"
24 #include "tensorflow/core/platform/mutex.h"
25
26 namespace tensorflow {
27
28 namespace table {
29
~Cache()30 Cache::~Cache() {}
31
32 namespace {
33
34 // LRU cache implementation
35 //
36 // Cache entries have an "in_cache" boolean indicating whether the cache has a
37 // reference on the entry. The only ways that this can become false without the
38 // entry being passed to its "deleter" are via Erase(), via Insert() when
39 // an element with a duplicate key is inserted, or on destruction of the cache.
40 //
41 // The cache keeps two linked lists of items in the cache. All items in the
42 // cache are in one list or the other, and never both. Items still referenced
43 // by clients but erased from the cache are in neither list. The lists are:
44 // - in-use: contains the items currently referenced by clients, in no
45 // particular order. (This list is used for invariant checking. If we
46 // removed the check, elements that would otherwise be on this list could be
47 // left as disconnected singleton lists.)
// - LRU: contains the items not currently referenced by clients, in LRU order
//   (oldest first).
49 // Elements are moved between these lists by the Ref() and Unref() methods,
50 // when they detect an element in the cache acquiring or losing its only
51 // external reference.
52
53 // An entry is a variable length heap-allocated structure. Entries
54 // are kept in a circular doubly linked list ordered by access time.
55 struct LRUHandle {
56 void* value;
57 void (*deleter)(const Slice&, void* value);
58 LRUHandle* next_hash;
59 LRUHandle* next;
60 LRUHandle* prev;
61 size_t charge; // TODO(opt): Only allow uint32_t?
62 size_t key_length;
63 bool in_cache; // Whether entry is in the cache.
64 uint32_t refs; // References, including cache reference, if present.
65 uint32_t hash; // Hash of key(); used for fast sharding and comparisons
66 char key_data[1]; // Beginning of key
67
keytensorflow::table::__anon27b9c5620111::LRUHandle68 Slice key() const {
69 // next_ is only equal to this if the LRU handle is the list head of an
70 // empty list. List heads never have meaningful keys.
71 assert(next != this);
72
73 return Slice(key_data, key_length);
74 }
75 };
76
77 // We provide our own simple hash table since it removes a whole bunch
78 // of porting hacks and is also faster than some of the built-in hash
79 // table implementations in some of the compiler/runtime combinations
80 // we have tested. E.g., readrandom speeds up by ~5% over the g++
81 // 4.4.3's builtin hashtable.
82 class HandleTable {
83 public:
HandleTable()84 HandleTable() : length_(0), elems_(0), list_(nullptr) { Resize(); }
~HandleTable()85 ~HandleTable() { delete[] list_; }
86
Lookup(const Slice & key,uint32_t hash)87 LRUHandle* Lookup(const Slice& key, uint32_t hash) {
88 return *FindPointer(key, hash);
89 }
90
Insert(LRUHandle * h)91 LRUHandle* Insert(LRUHandle* h) {
92 LRUHandle** ptr = FindPointer(h->key(), h->hash);
93 LRUHandle* old = *ptr;
94 h->next_hash = (old == nullptr ? nullptr : old->next_hash);
95 *ptr = h;
96 if (old == nullptr) {
97 ++elems_;
98 if (elems_ > length_) {
99 // Since each cache entry is fairly large, we aim for a small
100 // average linked list length (<= 1).
101 Resize();
102 }
103 }
104 return old;
105 }
106
Remove(const Slice & key,uint32_t hash)107 LRUHandle* Remove(const Slice& key, uint32_t hash) {
108 LRUHandle** ptr = FindPointer(key, hash);
109 LRUHandle* result = *ptr;
110 if (result != nullptr) {
111 *ptr = result->next_hash;
112 --elems_;
113 }
114 return result;
115 }
116
117 private:
118 // The table consists of an array of buckets where each bucket is
119 // a linked list of cache entries that hash into the bucket.
120 uint32_t length_;
121 uint32_t elems_;
122 LRUHandle** list_;
123
124 // Return a pointer to slot that points to a cache entry that
125 // matches key/hash. If there is no such cache entry, return a
126 // pointer to the trailing slot in the corresponding linked list.
FindPointer(const Slice & key,uint32_t hash)127 LRUHandle** FindPointer(const Slice& key, uint32_t hash) {
128 LRUHandle** ptr = &list_[hash & (length_ - 1)];
129 while (*ptr != nullptr && ((*ptr)->hash != hash || key != (*ptr)->key())) {
130 ptr = &(*ptr)->next_hash;
131 }
132 return ptr;
133 }
134
Resize()135 void Resize() {
136 uint32_t new_length = 4;
137 while (new_length < elems_) {
138 new_length *= 2;
139 }
140 LRUHandle** new_list = new LRUHandle*[new_length];
141 memset(new_list, 0, sizeof(new_list[0]) * new_length);
142 uint32_t count = 0;
143 for (uint32_t i = 0; i < length_; i++) {
144 LRUHandle* h = list_[i];
145 while (h != nullptr) {
146 LRUHandle* next = h->next_hash;
147 uint32_t hash = h->hash;
148 LRUHandle** ptr = &new_list[hash & (new_length - 1)];
149 h->next_hash = *ptr;
150 *ptr = h;
151 h = next;
152 count++;
153 }
154 }
155 assert(elems_ == count);
156 delete[] list_;
157 list_ = new_list;
158 length_ = new_length;
159 }
160 };
161
162 // A single shard of sharded cache.
163 class LRUCache {
164 public:
165 LRUCache();
166 ~LRUCache();
167
168 // Separate from constructor so caller can easily make an array of LRUCache
SetCapacity(size_t capacity)169 void SetCapacity(size_t capacity) { capacity_ = capacity; }
170
171 // Like Cache methods, but with an extra "hash" parameter.
172 Cache::Handle* Insert(const Slice& key, uint32_t hash, void* value,
173 size_t charge,
174 void (*deleter)(const Slice& key, void* value));
175 Cache::Handle* Lookup(const Slice& key, uint32_t hash);
176 void Release(Cache::Handle* handle);
177 void Erase(const Slice& key, uint32_t hash);
178 void Prune();
TotalCharge() const179 size_t TotalCharge() const {
180 mutex_lock l(mutex_);
181 return usage_;
182 }
183
184 private:
185 void LRU_Remove(LRUHandle* e);
186 void LRU_Append(LRUHandle* list, LRUHandle* e);
187 void Ref(LRUHandle* e);
188 void Unref(LRUHandle* e);
189 bool FinishErase(LRUHandle* e) TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
190
191 // Initialized before use.
192 size_t capacity_;
193
194 // mutex_ protects the following state.
195 mutable mutex mutex_;
196 size_t usage_ TF_GUARDED_BY(mutex_);
197
198 // Dummy head of LRU list.
199 // lru.prev is newest entry, lru.next is oldest entry.
200 // Entries have refs==1 and in_cache==true.
201 LRUHandle lru_ TF_GUARDED_BY(mutex_);
202
203 // Dummy head of in-use list.
204 // Entries are in use by clients, and have refs >= 2 and in_cache==true.
205 LRUHandle in_use_ TF_GUARDED_BY(mutex_);
206
207 HandleTable table_ TF_GUARDED_BY(mutex_);
208 };
209
LRUCache()210 LRUCache::LRUCache() : capacity_(0), usage_(0) {
211 // Make empty circular linked lists.
212 lru_.next = &lru_;
213 lru_.prev = &lru_;
214 in_use_.next = &in_use_;
215 in_use_.prev = &in_use_;
216 }
217
~LRUCache()218 LRUCache::~LRUCache() {
219 assert(in_use_.next == &in_use_); // Error if caller has an unreleased handle
220 for (LRUHandle* e = lru_.next; e != &lru_;) {
221 LRUHandle* next = e->next;
222 assert(e->in_cache);
223 e->in_cache = false;
224 assert(e->refs == 1); // Invariant of lru_ list.
225 Unref(e);
226 e = next;
227 }
228 }
229
Ref(LRUHandle * e)230 void LRUCache::Ref(LRUHandle* e) {
231 if (e->refs == 1 && e->in_cache) { // If on lru_ list, move to in_use_ list.
232 LRU_Remove(e);
233 LRU_Append(&in_use_, e);
234 }
235 e->refs++;
236 }
237
Unref(LRUHandle * e)238 void LRUCache::Unref(LRUHandle* e) {
239 assert(e->refs > 0);
240 e->refs--;
241 if (e->refs == 0) { // Deallocate.
242 assert(!e->in_cache);
243 (*e->deleter)(e->key(), e->value);
244 free(e);
245 } else if (e->in_cache && e->refs == 1) {
246 // No longer in use; move to lru_ list.
247 LRU_Remove(e);
248 LRU_Append(&lru_, e);
249 }
250 }
251
LRU_Remove(LRUHandle * e)252 void LRUCache::LRU_Remove(LRUHandle* e) {
253 e->next->prev = e->prev;
254 e->prev->next = e->next;
255 }
256
LRU_Append(LRUHandle * list,LRUHandle * e)257 void LRUCache::LRU_Append(LRUHandle* list, LRUHandle* e) {
258 // Make "e" newest entry by inserting just before *list
259 e->next = list;
260 e->prev = list->prev;
261 e->prev->next = e;
262 e->next->prev = e;
263 }
264
Lookup(const Slice & key,uint32_t hash)265 Cache::Handle* LRUCache::Lookup(const Slice& key, uint32_t hash) {
266 mutex_lock l(mutex_);
267 LRUHandle* e = table_.Lookup(key, hash);
268 if (e != nullptr) {
269 Ref(e);
270 }
271 return reinterpret_cast<Cache::Handle*>(e);
272 }
273
Release(Cache::Handle * handle)274 void LRUCache::Release(Cache::Handle* handle) {
275 mutex_lock l(mutex_);
276 Unref(reinterpret_cast<LRUHandle*>(handle));
277 }
278
Insert(const Slice & key,uint32_t hash,void * value,size_t charge,void (* deleter)(const Slice & key,void * value))279 Cache::Handle* LRUCache::Insert(const Slice& key, uint32_t hash, void* value,
280 size_t charge,
281 void (*deleter)(const Slice& key,
282 void* value)) {
283 mutex_lock l(mutex_);
284
285 LRUHandle* e =
286 reinterpret_cast<LRUHandle*>(malloc(sizeof(LRUHandle) - 1 + key.size()));
287 e->value = value;
288 e->deleter = deleter;
289 e->charge = charge;
290 e->key_length = key.size();
291 e->hash = hash;
292 e->in_cache = false;
293 e->refs = 1; // for the returned handle.
294 memcpy(e->key_data, key.data(), key.size());
295
296 if (capacity_ > 0) {
297 e->refs++; // for the cache's reference.
298 e->in_cache = true;
299 LRU_Append(&in_use_, e);
300 usage_ += charge;
301 FinishErase(table_.Insert(e));
302 } else { // don't cache. (capacity_==0 is supported and turns off caching.)
303 // next is read by key() in an assert, so it must be initialized
304 e->next = nullptr;
305 }
306 while (usage_ > capacity_ && lru_.next != &lru_) {
307 LRUHandle* old = lru_.next;
308 assert(old->refs == 1);
309 bool erased = FinishErase(table_.Remove(old->key(), old->hash));
310 if (!erased) { // to avoid unused variable when compiled NDEBUG
311 assert(erased);
312 }
313 }
314
315 return reinterpret_cast<Cache::Handle*>(e);
316 }
317
318 // If e != nullptr, finish removing *e from the cache; it has already been
319 // removed from the hash table. Return whether e != nullptr.
FinishErase(LRUHandle * e)320 bool LRUCache::FinishErase(LRUHandle* e) {
321 if (e != nullptr) {
322 assert(e->in_cache);
323 LRU_Remove(e);
324 e->in_cache = false;
325 usage_ -= e->charge;
326 Unref(e);
327 }
328 return e != nullptr;
329 }
330
Erase(const Slice & key,uint32_t hash)331 void LRUCache::Erase(const Slice& key, uint32_t hash) {
332 mutex_lock l(mutex_);
333 FinishErase(table_.Remove(key, hash));
334 }
335
Prune()336 void LRUCache::Prune() {
337 mutex_lock l(mutex_);
338 while (lru_.next != &lru_) {
339 LRUHandle* e = lru_.next;
340 assert(e->refs == 1);
341 bool erased = FinishErase(table_.Remove(e->key(), e->hash));
342 if (!erased) { // to avoid unused variable when compiled NDEBUG
343 assert(erased);
344 }
345 }
346 }
347
// The sharded cache consists of 2^kNumShardBits independent LRUCache shards.
constexpr int kNumShardBits = 4;
constexpr int kNumShards = 1 << kNumShardBits;
350
351 class ShardedLRUCache : public Cache {
352 private:
353 LRUCache shard_[kNumShards];
354 mutex id_mutex_;
355 uint64_t last_id_;
356
HashSlice(const Slice & s)357 static inline uint32_t HashSlice(const Slice& s) {
358 return Hash(s.data(), s.size(), 0);
359 }
360
Shard(uint32_t hash)361 static uint32_t Shard(uint32_t hash) { return hash >> (32 - kNumShardBits); }
362
363 public:
ShardedLRUCache(size_t capacity)364 explicit ShardedLRUCache(size_t capacity) : last_id_(0) {
365 const size_t per_shard = (capacity + (kNumShards - 1)) / kNumShards;
366 for (int s = 0; s < kNumShards; s++) {
367 shard_[s].SetCapacity(per_shard);
368 }
369 }
~ShardedLRUCache()370 ~ShardedLRUCache() override {}
Insert(const Slice & key,void * value,size_t charge,void (* deleter)(const Slice & key,void * value))371 Handle* Insert(const Slice& key, void* value, size_t charge,
372 void (*deleter)(const Slice& key, void* value)) override {
373 const uint32_t hash = HashSlice(key);
374 return shard_[Shard(hash)].Insert(key, hash, value, charge, deleter);
375 }
Lookup(const Slice & key)376 Handle* Lookup(const Slice& key) override {
377 const uint32_t hash = HashSlice(key);
378 return shard_[Shard(hash)].Lookup(key, hash);
379 }
Release(Handle * handle)380 void Release(Handle* handle) override {
381 LRUHandle* h = reinterpret_cast<LRUHandle*>(handle);
382 shard_[Shard(h->hash)].Release(handle);
383 }
Erase(const Slice & key)384 void Erase(const Slice& key) override {
385 const uint32_t hash = HashSlice(key);
386 shard_[Shard(hash)].Erase(key, hash);
387 }
Value(Handle * handle)388 void* Value(Handle* handle) override {
389 return reinterpret_cast<LRUHandle*>(handle)->value;
390 }
NewId()391 uint64_t NewId() override {
392 mutex_lock l(id_mutex_);
393 return ++(last_id_);
394 }
Prune()395 void Prune() override {
396 for (int s = 0; s < kNumShards; s++) {
397 shard_[s].Prune();
398 }
399 }
TotalCharge() const400 size_t TotalCharge() const override {
401 size_t total = 0;
402 for (int s = 0; s < kNumShards; s++) {
403 total += shard_[s].TotalCharge();
404 }
405 return total;
406 }
407
408 private:
409 // TODO(byronyi): Figure out why Hash32 fails EvictionPolicy test.
Hash(const char * data,size_t n,uint32_t seed)410 static uint32_t Hash(const char* data, size_t n, uint32_t seed) {
411 // Similar to murmur hash
412 const uint32_t m = 0xc6a4a793;
413 const uint32_t r = 24;
414 const char* limit = data + n;
415 uint32_t h = seed ^ (n * m);
416
417 // Pick up four bytes at a time
418 while (data + 4 <= limit) {
419 uint32_t w = core::DecodeFixed32(data);
420 data += 4;
421 h += w;
422 h *= m;
423 h ^= (h >> 16);
424 }
425
426 // Pick up remaining bytes
427 switch (limit - data) {
428 case 3:
429 h += static_cast<uint8_t>(data[2]) << 16;
430 ABSL_FALLTHROUGH_INTENDED;
431 case 2:
432 h += static_cast<uint8_t>(data[1]) << 8;
433 ABSL_FALLTHROUGH_INTENDED;
434 case 1:
435 h += static_cast<uint8_t>(data[0]);
436 h *= m;
437 h ^= (h >> r);
438 break;
439 }
440 return h;
441 }
442 };
443
444 } // end anonymous namespace
445
NewLRUCache(size_t capacity)446 Cache* NewLRUCache(size_t capacity) { return new ShardedLRUCache(capacity); }
447
448 } // namespace table
449
450 } // namespace tensorflow
451