1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/lib/io/cache.h"
17 
18 #include <assert.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22 
23 #include "tensorflow/core/platform/coding.h"
24 #include "tensorflow/core/platform/mutex.h"
25 
26 namespace tensorflow {
27 
28 namespace table {
29 
~Cache()30 Cache::~Cache() {}
31 
32 namespace {
33 
34 // LRU cache implementation
35 //
36 // Cache entries have an "in_cache" boolean indicating whether the cache has a
37 // reference on the entry.  The only ways that this can become false without the
38 // entry being passed to its "deleter" are via Erase(), via Insert() when
39 // an element with a duplicate key is inserted, or on destruction of the cache.
40 //
41 // The cache keeps two linked lists of items in the cache.  All items in the
42 // cache are in one list or the other, and never both.  Items still referenced
43 // by clients but erased from the cache are in neither list.  The lists are:
44 // - in-use:  contains the items currently referenced by clients, in no
45 //   particular order.  (This list is used for invariant checking.  If we
46 //   removed the check, elements that would otherwise be on this list could be
47 //   left as disconnected singleton lists.)
48 // - LRU:  contains the items not currently referenced by clients, in LRU order
49 // Elements are moved between these lists by the Ref() and Unref() methods,
50 // when they detect an element in the cache acquiring or losing its only
51 // external reference.
52 
53 // An entry is a variable length heap-allocated structure.  Entries
54 // are kept in a circular doubly linked list ordered by access time.
55 struct LRUHandle {
56   void* value;
57   void (*deleter)(const Slice&, void* value);
58   LRUHandle* next_hash;
59   LRUHandle* next;
60   LRUHandle* prev;
61   size_t charge;  // TODO(opt): Only allow uint32_t?
62   size_t key_length;
63   bool in_cache;     // Whether entry is in the cache.
64   uint32_t refs;     // References, including cache reference, if present.
65   uint32_t hash;     // Hash of key(); used for fast sharding and comparisons
66   char key_data[1];  // Beginning of key
67 
keytensorflow::table::__anon27b9c5620111::LRUHandle68   Slice key() const {
69     // next_ is only equal to this if the LRU handle is the list head of an
70     // empty list. List heads never have meaningful keys.
71     assert(next != this);
72 
73     return Slice(key_data, key_length);
74   }
75 };
76 
77 // We provide our own simple hash table since it removes a whole bunch
78 // of porting hacks and is also faster than some of the built-in hash
79 // table implementations in some of the compiler/runtime combinations
80 // we have tested.  E.g., readrandom speeds up by ~5% over the g++
81 // 4.4.3's builtin hashtable.
82 class HandleTable {
83  public:
HandleTable()84   HandleTable() : length_(0), elems_(0), list_(nullptr) { Resize(); }
~HandleTable()85   ~HandleTable() { delete[] list_; }
86 
Lookup(const Slice & key,uint32_t hash)87   LRUHandle* Lookup(const Slice& key, uint32_t hash) {
88     return *FindPointer(key, hash);
89   }
90 
Insert(LRUHandle * h)91   LRUHandle* Insert(LRUHandle* h) {
92     LRUHandle** ptr = FindPointer(h->key(), h->hash);
93     LRUHandle* old = *ptr;
94     h->next_hash = (old == nullptr ? nullptr : old->next_hash);
95     *ptr = h;
96     if (old == nullptr) {
97       ++elems_;
98       if (elems_ > length_) {
99         // Since each cache entry is fairly large, we aim for a small
100         // average linked list length (<= 1).
101         Resize();
102       }
103     }
104     return old;
105   }
106 
Remove(const Slice & key,uint32_t hash)107   LRUHandle* Remove(const Slice& key, uint32_t hash) {
108     LRUHandle** ptr = FindPointer(key, hash);
109     LRUHandle* result = *ptr;
110     if (result != nullptr) {
111       *ptr = result->next_hash;
112       --elems_;
113     }
114     return result;
115   }
116 
117  private:
118   // The table consists of an array of buckets where each bucket is
119   // a linked list of cache entries that hash into the bucket.
120   uint32_t length_;
121   uint32_t elems_;
122   LRUHandle** list_;
123 
124   // Return a pointer to slot that points to a cache entry that
125   // matches key/hash.  If there is no such cache entry, return a
126   // pointer to the trailing slot in the corresponding linked list.
FindPointer(const Slice & key,uint32_t hash)127   LRUHandle** FindPointer(const Slice& key, uint32_t hash) {
128     LRUHandle** ptr = &list_[hash & (length_ - 1)];
129     while (*ptr != nullptr && ((*ptr)->hash != hash || key != (*ptr)->key())) {
130       ptr = &(*ptr)->next_hash;
131     }
132     return ptr;
133   }
134 
Resize()135   void Resize() {
136     uint32_t new_length = 4;
137     while (new_length < elems_) {
138       new_length *= 2;
139     }
140     LRUHandle** new_list = new LRUHandle*[new_length];
141     memset(new_list, 0, sizeof(new_list[0]) * new_length);
142     uint32_t count = 0;
143     for (uint32_t i = 0; i < length_; i++) {
144       LRUHandle* h = list_[i];
145       while (h != nullptr) {
146         LRUHandle* next = h->next_hash;
147         uint32_t hash = h->hash;
148         LRUHandle** ptr = &new_list[hash & (new_length - 1)];
149         h->next_hash = *ptr;
150         *ptr = h;
151         h = next;
152         count++;
153       }
154     }
155     assert(elems_ == count);
156     delete[] list_;
157     list_ = new_list;
158     length_ = new_length;
159   }
160 };
161 
162 // A single shard of sharded cache.
163 class LRUCache {
164  public:
165   LRUCache();
166   ~LRUCache();
167 
168   // Separate from constructor so caller can easily make an array of LRUCache
SetCapacity(size_t capacity)169   void SetCapacity(size_t capacity) { capacity_ = capacity; }
170 
171   // Like Cache methods, but with an extra "hash" parameter.
172   Cache::Handle* Insert(const Slice& key, uint32_t hash, void* value,
173                         size_t charge,
174                         void (*deleter)(const Slice& key, void* value));
175   Cache::Handle* Lookup(const Slice& key, uint32_t hash);
176   void Release(Cache::Handle* handle);
177   void Erase(const Slice& key, uint32_t hash);
178   void Prune();
TotalCharge() const179   size_t TotalCharge() const {
180     mutex_lock l(mutex_);
181     return usage_;
182   }
183 
184  private:
185   void LRU_Remove(LRUHandle* e);
186   void LRU_Append(LRUHandle* list, LRUHandle* e);
187   void Ref(LRUHandle* e);
188   void Unref(LRUHandle* e);
189   bool FinishErase(LRUHandle* e) TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
190 
191   // Initialized before use.
192   size_t capacity_;
193 
194   // mutex_ protects the following state.
195   mutable mutex mutex_;
196   size_t usage_ TF_GUARDED_BY(mutex_);
197 
198   // Dummy head of LRU list.
199   // lru.prev is newest entry, lru.next is oldest entry.
200   // Entries have refs==1 and in_cache==true.
201   LRUHandle lru_ TF_GUARDED_BY(mutex_);
202 
203   // Dummy head of in-use list.
204   // Entries are in use by clients, and have refs >= 2 and in_cache==true.
205   LRUHandle in_use_ TF_GUARDED_BY(mutex_);
206 
207   HandleTable table_ TF_GUARDED_BY(mutex_);
208 };
209 
LRUCache()210 LRUCache::LRUCache() : capacity_(0), usage_(0) {
211   // Make empty circular linked lists.
212   lru_.next = &lru_;
213   lru_.prev = &lru_;
214   in_use_.next = &in_use_;
215   in_use_.prev = &in_use_;
216 }
217 
~LRUCache()218 LRUCache::~LRUCache() {
219   assert(in_use_.next == &in_use_);  // Error if caller has an unreleased handle
220   for (LRUHandle* e = lru_.next; e != &lru_;) {
221     LRUHandle* next = e->next;
222     assert(e->in_cache);
223     e->in_cache = false;
224     assert(e->refs == 1);  // Invariant of lru_ list.
225     Unref(e);
226     e = next;
227   }
228 }
229 
Ref(LRUHandle * e)230 void LRUCache::Ref(LRUHandle* e) {
231   if (e->refs == 1 && e->in_cache) {  // If on lru_ list, move to in_use_ list.
232     LRU_Remove(e);
233     LRU_Append(&in_use_, e);
234   }
235   e->refs++;
236 }
237 
Unref(LRUHandle * e)238 void LRUCache::Unref(LRUHandle* e) {
239   assert(e->refs > 0);
240   e->refs--;
241   if (e->refs == 0) {  // Deallocate.
242     assert(!e->in_cache);
243     (*e->deleter)(e->key(), e->value);
244     free(e);
245   } else if (e->in_cache && e->refs == 1) {
246     // No longer in use; move to lru_ list.
247     LRU_Remove(e);
248     LRU_Append(&lru_, e);
249   }
250 }
251 
LRU_Remove(LRUHandle * e)252 void LRUCache::LRU_Remove(LRUHandle* e) {
253   e->next->prev = e->prev;
254   e->prev->next = e->next;
255 }
256 
LRU_Append(LRUHandle * list,LRUHandle * e)257 void LRUCache::LRU_Append(LRUHandle* list, LRUHandle* e) {
258   // Make "e" newest entry by inserting just before *list
259   e->next = list;
260   e->prev = list->prev;
261   e->prev->next = e;
262   e->next->prev = e;
263 }
264 
Lookup(const Slice & key,uint32_t hash)265 Cache::Handle* LRUCache::Lookup(const Slice& key, uint32_t hash) {
266   mutex_lock l(mutex_);
267   LRUHandle* e = table_.Lookup(key, hash);
268   if (e != nullptr) {
269     Ref(e);
270   }
271   return reinterpret_cast<Cache::Handle*>(e);
272 }
273 
Release(Cache::Handle * handle)274 void LRUCache::Release(Cache::Handle* handle) {
275   mutex_lock l(mutex_);
276   Unref(reinterpret_cast<LRUHandle*>(handle));
277 }
278 
Insert(const Slice & key,uint32_t hash,void * value,size_t charge,void (* deleter)(const Slice & key,void * value))279 Cache::Handle* LRUCache::Insert(const Slice& key, uint32_t hash, void* value,
280                                 size_t charge,
281                                 void (*deleter)(const Slice& key,
282                                                 void* value)) {
283   mutex_lock l(mutex_);
284 
285   LRUHandle* e =
286       reinterpret_cast<LRUHandle*>(malloc(sizeof(LRUHandle) - 1 + key.size()));
287   e->value = value;
288   e->deleter = deleter;
289   e->charge = charge;
290   e->key_length = key.size();
291   e->hash = hash;
292   e->in_cache = false;
293   e->refs = 1;  // for the returned handle.
294   memcpy(e->key_data, key.data(), key.size());
295 
296   if (capacity_ > 0) {
297     e->refs++;  // for the cache's reference.
298     e->in_cache = true;
299     LRU_Append(&in_use_, e);
300     usage_ += charge;
301     FinishErase(table_.Insert(e));
302   } else {  // don't cache. (capacity_==0 is supported and turns off caching.)
303     // next is read by key() in an assert, so it must be initialized
304     e->next = nullptr;
305   }
306   while (usage_ > capacity_ && lru_.next != &lru_) {
307     LRUHandle* old = lru_.next;
308     assert(old->refs == 1);
309     bool erased = FinishErase(table_.Remove(old->key(), old->hash));
310     if (!erased) {  // to avoid unused variable when compiled NDEBUG
311       assert(erased);
312     }
313   }
314 
315   return reinterpret_cast<Cache::Handle*>(e);
316 }
317 
318 // If e != nullptr, finish removing *e from the cache; it has already been
319 // removed from the hash table.  Return whether e != nullptr.
FinishErase(LRUHandle * e)320 bool LRUCache::FinishErase(LRUHandle* e) {
321   if (e != nullptr) {
322     assert(e->in_cache);
323     LRU_Remove(e);
324     e->in_cache = false;
325     usage_ -= e->charge;
326     Unref(e);
327   }
328   return e != nullptr;
329 }
330 
Erase(const Slice & key,uint32_t hash)331 void LRUCache::Erase(const Slice& key, uint32_t hash) {
332   mutex_lock l(mutex_);
333   FinishErase(table_.Remove(key, hash));
334 }
335 
Prune()336 void LRUCache::Prune() {
337   mutex_lock l(mutex_);
338   while (lru_.next != &lru_) {
339     LRUHandle* e = lru_.next;
340     assert(e->refs == 1);
341     bool erased = FinishErase(table_.Remove(e->key(), e->hash));
342     if (!erased) {  // to avoid unused variable when compiled NDEBUG
343       assert(erased);
344     }
345   }
346 }
347 
// The cache is divided into 2^kNumShardBits shards.
constexpr int kNumShardBits = 4;
constexpr int kNumShards = 1 << kNumShardBits;
350 
351 class ShardedLRUCache : public Cache {
352  private:
353   LRUCache shard_[kNumShards];
354   mutex id_mutex_;
355   uint64_t last_id_;
356 
HashSlice(const Slice & s)357   static inline uint32_t HashSlice(const Slice& s) {
358     return Hash(s.data(), s.size(), 0);
359   }
360 
Shard(uint32_t hash)361   static uint32_t Shard(uint32_t hash) { return hash >> (32 - kNumShardBits); }
362 
363  public:
ShardedLRUCache(size_t capacity)364   explicit ShardedLRUCache(size_t capacity) : last_id_(0) {
365     const size_t per_shard = (capacity + (kNumShards - 1)) / kNumShards;
366     for (int s = 0; s < kNumShards; s++) {
367       shard_[s].SetCapacity(per_shard);
368     }
369   }
~ShardedLRUCache()370   ~ShardedLRUCache() override {}
Insert(const Slice & key,void * value,size_t charge,void (* deleter)(const Slice & key,void * value))371   Handle* Insert(const Slice& key, void* value, size_t charge,
372                  void (*deleter)(const Slice& key, void* value)) override {
373     const uint32_t hash = HashSlice(key);
374     return shard_[Shard(hash)].Insert(key, hash, value, charge, deleter);
375   }
Lookup(const Slice & key)376   Handle* Lookup(const Slice& key) override {
377     const uint32_t hash = HashSlice(key);
378     return shard_[Shard(hash)].Lookup(key, hash);
379   }
Release(Handle * handle)380   void Release(Handle* handle) override {
381     LRUHandle* h = reinterpret_cast<LRUHandle*>(handle);
382     shard_[Shard(h->hash)].Release(handle);
383   }
Erase(const Slice & key)384   void Erase(const Slice& key) override {
385     const uint32_t hash = HashSlice(key);
386     shard_[Shard(hash)].Erase(key, hash);
387   }
Value(Handle * handle)388   void* Value(Handle* handle) override {
389     return reinterpret_cast<LRUHandle*>(handle)->value;
390   }
NewId()391   uint64_t NewId() override {
392     mutex_lock l(id_mutex_);
393     return ++(last_id_);
394   }
Prune()395   void Prune() override {
396     for (int s = 0; s < kNumShards; s++) {
397       shard_[s].Prune();
398     }
399   }
TotalCharge() const400   size_t TotalCharge() const override {
401     size_t total = 0;
402     for (int s = 0; s < kNumShards; s++) {
403       total += shard_[s].TotalCharge();
404     }
405     return total;
406   }
407 
408  private:
409   // TODO(byronyi): Figure out why Hash32 fails EvictionPolicy test.
Hash(const char * data,size_t n,uint32_t seed)410   static uint32_t Hash(const char* data, size_t n, uint32_t seed) {
411     // Similar to murmur hash
412     const uint32_t m = 0xc6a4a793;
413     const uint32_t r = 24;
414     const char* limit = data + n;
415     uint32_t h = seed ^ (n * m);
416 
417     // Pick up four bytes at a time
418     while (data + 4 <= limit) {
419       uint32_t w = core::DecodeFixed32(data);
420       data += 4;
421       h += w;
422       h *= m;
423       h ^= (h >> 16);
424     }
425 
426     // Pick up remaining bytes
427     switch (limit - data) {
428       case 3:
429         h += static_cast<uint8_t>(data[2]) << 16;
430         ABSL_FALLTHROUGH_INTENDED;
431       case 2:
432         h += static_cast<uint8_t>(data[1]) << 8;
433         ABSL_FALLTHROUGH_INTENDED;
434       case 1:
435         h += static_cast<uint8_t>(data[0]);
436         h *= m;
437         h ^= (h >> r);
438         break;
439     }
440     return h;
441   }
442 };
443 
444 }  // end anonymous namespace
445 
NewLRUCache(size_t capacity)446 Cache* NewLRUCache(size_t capacity) { return new ShardedLRUCache(capacity); }
447 
448 }  // namespace table
449 
450 }  // namespace tensorflow
451