/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
#define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_

#include <string.h>
#include <utility>
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace gtl {
namespace internal {

28 // Internal representation for FlatMap and FlatSet.
29 //
30 // The representation is an open-addressed hash table.  Conceptually,
31 // the representation is a flat array of entries.  However we
32 // structure it as an array of buckets where each bucket holds
33 // kWidth entries along with metadata for the kWidth entries.  The
34 // metadata marker is
35 //
36 //  (a) kEmpty: the entry is empty
37 //  (b) kDeleted: the entry has been deleted
38 //  (c) other: the entry is occupied and has low-8 bits of its hash.
39 //      These hash bits can be used to avoid potentially expensive
40 //      key comparisons.
41 //
42 // FlatMap passes in a bucket that contains keys and values, FlatSet
43 // passes in a bucket that does not contain values.
44 template <typename Key, typename Bucket, class Hash, class Eq>
45 class FlatRep {
46  public:
47   // kWidth is the number of entries stored in a bucket.
48   static const uint32 kBase = 3;
49   static const uint32 kWidth = (1 << kBase);
50 
FlatRep(size_t N,const Hash & hf,const Eq & eq)51   FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) {
52     Init(N);
53   }
FlatRep(const FlatRep & src)54   FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
55     Init(src.size());
56     CopyEntries(src.array_, src.end_, CopyEntry());
57   }
58 
FlatRep(FlatRep && src)59   FlatRep(FlatRep&& src)
60       // Copy rather than move src.hash_ and src.equal_.  This is necessary to
61       // leave src in a valid state -- otherwise e.g. if hash_ is an
62       // std::function, moving it would null it out.
63       : hash_(src.hash_), equal_(src.equal_) {
64     // TODO(jlebar): Init(1) still allocates some memory, so this isn't as cheap
65     // as it could be.  The fundamental problem is that we need to leave src in
66     // a valid state, and FlatRep *always* owns a nonzero amount of memory.
67     Init(1);
68     swap(src);
69   }
70 
~FlatRep()71   ~FlatRep() {
72     clear_no_resize();
73     delete[] array_;
74   }
75 
76   // Simple accessors.
size()77   size_t size() const { return not_empty_ - deleted_; }
bucket_count()78   size_t bucket_count() const { return mask_ + 1; }
start()79   Bucket* start() const { return array_; }
limit()80   Bucket* limit() const { return end_; }
hash_function()81   const Hash& hash_function() const { return hash_; }
key_eq()82   const Eq& key_eq() const { return equal_; }
83 
84   // Overwrite contents of *this with contents of src.
CopyFrom(const FlatRep & src)85   void CopyFrom(const FlatRep& src) {
86     if (this != &src) {
87       clear_no_resize();
88       delete[] array_;
89       Init(src.size());
90       CopyEntries(src.array_, src.end_, CopyEntry());
91     }
92   }
93 
MoveFrom(FlatRep && src)94   void MoveFrom(FlatRep&& src) {
95     if (this != &src) {
96       swap(src);
97     }
98   }
99 
clear_no_resize()100   void clear_no_resize() {
101     for (Bucket* b = array_; b != end_; b++) {
102       for (uint32 i = 0; i < kWidth; i++) {
103         if (b->marker[i] >= 2) {
104           b->Destroy(i);
105           b->marker[i] = kEmpty;
106         }
107       }
108     }
109     not_empty_ = 0;
110     deleted_ = 0;
111   }
112 
clear()113   void clear() {
114     clear_no_resize();
115     grow_ = 0;  // Consider shrinking in MaybeResize()
116     MaybeResize();
117   }
118 
swap(FlatRep & x)119   void swap(FlatRep& x) {
120     using std::swap;
121     swap(array_, x.array_);
122     swap(end_, x.end_);
123     swap(lglen_, x.lglen_);
124     swap(mask_, x.mask_);
125     swap(not_empty_, x.not_empty_);
126     swap(deleted_, x.deleted_);
127     swap(grow_, x.grow_);
128     swap(shrink_, x.shrink_);
129   }
130 
131   struct SearchResult {
132     bool found;
133     Bucket* b;
134     uint32 index;
135   };
136 
137   // Hash value is partitioned as follows:
138   // 1. Bottom 8 bits are stored in bucket to help speed up comparisons.
139   // 2. Next 3 bits give index inside bucket.
140   // 3. Remaining bits give bucket number.
141 
142   // Find bucket/index for key k.
Find(const Key & k)143   SearchResult Find(const Key& k) const {
144     size_t h = hash_(k);
145     const uint32 marker = Marker(h & 0xff);
146     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
147     uint32 num_probes = 1;            // Needed for quadratic probing
148     while (true) {
149       uint32 bi = index & (kWidth - 1);
150       Bucket* b = &array_[index >> kBase];
151       const uint32 x = b->marker[bi];
152       if (x == marker && equal_(b->key(bi), k)) {
153         return {true, b, bi};
154       } else if (x == kEmpty) {
155         return {false, nullptr, 0};
156       }
157       index = NextIndex(index, num_probes);
158       num_probes++;
159     }
160   }
161 
162   // Find bucket/index for key k, creating a new one if necessary.
163   //
164   // KeyType is a template parameter so that k's type is deduced and it
165   // becomes a universal reference which allows the key initialization
166   // below to use an rvalue constructor if available.
167   template <typename KeyType>
FindOrInsert(KeyType && k)168   SearchResult FindOrInsert(KeyType&& k) {
169     size_t h = hash_(k);
170     const uint32 marker = Marker(h & 0xff);
171     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
172     uint32 num_probes = 1;            // Needed for quadratic probing
173     Bucket* del = nullptr;            // First encountered deletion for kInsert
174     uint32 di = 0;
175     while (true) {
176       uint32 bi = index & (kWidth - 1);
177       Bucket* b = &array_[index >> kBase];
178       const uint32 x = b->marker[bi];
179       if (x == marker && equal_(b->key(bi), k)) {
180         return {true, b, bi};
181       } else if (!del && x == kDeleted) {
182         // Remember deleted index to use for insertion.
183         del = b;
184         di = bi;
185       } else if (x == kEmpty) {
186         if (del) {
187           // Store in the first deleted slot we encountered
188           b = del;
189           bi = di;
190           deleted_--;  // not_empty_ does not change
191         } else {
192           not_empty_++;
193         }
194         b->marker[bi] = marker;
195         new (&b->key(bi)) Key(std::forward<KeyType>(k));
196         return {false, b, bi};
197       }
198       index = NextIndex(index, num_probes);
199       num_probes++;
200     }
201   }
202 
Erase(Bucket * b,uint32 i)203   void Erase(Bucket* b, uint32 i) {
204     b->Destroy(i);
205     b->marker[i] = kDeleted;
206     deleted_++;
207     grow_ = 0;  // Consider shrinking on next insert
208   }
209 
Prefetch(const Key & k)210   void Prefetch(const Key& k) const {
211     size_t h = hash_(k);
212     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
213     uint32 bi = index & (kWidth - 1);
214     Bucket* b = &array_[index >> kBase];
215     port::prefetch<port::PREFETCH_HINT_T0>(&b->marker[bi]);
216     port::prefetch<port::PREFETCH_HINT_T0>(&b->storage.key[bi]);
217   }
218 
MaybeResize()219   inline void MaybeResize() {
220     if (not_empty_ < grow_) {
221       return;  // Nothing to do
222     }
223     if (grow_ == 0) {
224       // Special value set by erase to cause shrink on next insert.
225       if (size() >= shrink_) {
226         // Not small enough to shrink.
227         grow_ = static_cast<size_t>(bucket_count() * 0.8);
228         if (not_empty_ < grow_) return;
229       }
230     }
231     Resize(size() + 1);
232   }
233 
Resize(size_t N)234   void Resize(size_t N) {
235     Bucket* old = array_;
236     Bucket* old_end = end_;
237     Init(N);
238     CopyEntries(old, old_end, MoveEntry());
239     delete[] old;
240   }
241 
242  private:
243   enum { kEmpty = 0, kDeleted = 1 };  // Special markers for an entry.
244 
245   Hash hash_;         // User-supplied hasher
246   Eq equal_;          // User-supplied comparator
247   uint8 lglen_;       // lg(#buckets)
248   Bucket* array_;     // array of length (1 << lglen_)
249   Bucket* end_;       // Points just past last bucket in array_
250   size_t mask_;       // (# of entries in table) - 1
251   size_t not_empty_;  // Count of entries with marker != kEmpty
252   size_t deleted_;    // Count of entries with marker == kDeleted
253   size_t grow_;       // Grow array when not_empty_ >= grow_
254   size_t shrink_;     // Shrink array when size() < shrink_
255 
256   // Avoid kEmpty and kDeleted markers when computing hash values to
257   // store in Bucket::marker[].
Marker(uint32 hb)258   static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); }
259 
Init(size_t N)260   void Init(size_t N) {
261     // Make enough room for N elements.
262     size_t lg = 0;  // Smallest table is just one bucket.
263     while (N >= 0.8 * ((1 << lg) * kWidth)) {
264       lg++;
265     }
266     const size_t n = (1 << lg);
267     Bucket* array = new Bucket[n];
268     for (size_t i = 0; i < n; i++) {
269       Bucket* b = &array[i];
270       memset(b->marker, kEmpty, kWidth);
271     }
272     const size_t capacity = (1 << lg) * kWidth;
273     lglen_ = lg;
274     mask_ = capacity - 1;
275     array_ = array;
276     end_ = array + n;
277     not_empty_ = 0;
278     deleted_ = 0;
279     grow_ = static_cast<size_t>(capacity * 0.8);
280     if (lg == 0) {
281       // Already down to one bucket; no more shrinking.
282       shrink_ = 0;
283     } else {
284       shrink_ = static_cast<size_t>(grow_ * 0.4);  // Must be less than 0.5
285     }
286   }
287 
288   // Used by FreshInsert when we should copy from source.
289   struct CopyEntry {
operatorCopyEntry290     inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
291       dst->CopyFrom(dsti, src, srci);
292     }
293   };
294 
295   // Used by FreshInsert when we should move from source.
296   struct MoveEntry {
operatorMoveEntry297     inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
298       dst->MoveFrom(dsti, src, srci);
299       src->Destroy(srci);
300       src->marker[srci] = kDeleted;
301     }
302   };
303 
304   template <typename Copier>
CopyEntries(Bucket * start,Bucket * end,Copier copier)305   void CopyEntries(Bucket* start, Bucket* end, Copier copier) {
306     for (Bucket* b = start; b != end; b++) {
307       for (uint32 i = 0; i < kWidth; i++) {
308         if (b->marker[i] >= 2) {
309           FreshInsert(b, i, copier);
310         }
311       }
312     }
313   }
314 
315   // Create an entry for the key numbered src_index in *src and return
316   // its bucket/index.  Used for insertion into a fresh table.  We
317   // assume that there are no deletions, and k does not already exist
318   // in the table.
319   template <typename Copier>
FreshInsert(Bucket * src,uint32 src_index,Copier copier)320   void FreshInsert(Bucket* src, uint32 src_index, Copier copier) {
321     size_t h = hash_(src->key(src_index));
322     const uint32 marker = Marker(h & 0xff);
323     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
324     uint32 num_probes = 1;            // Needed for quadratic probing
325     while (true) {
326       uint32 bi = index & (kWidth - 1);
327       Bucket* b = &array_[index >> kBase];
328       const uint32 x = b->marker[bi];
329       if (x == 0) {
330         b->marker[bi] = marker;
331         not_empty_++;
332         copier(b, bi, src, src_index);
333         return;
334       }
335       index = NextIndex(index, num_probes);
336       num_probes++;
337     }
338   }
339 
NextIndex(size_t i,uint32 num_probes)340   inline size_t NextIndex(size_t i, uint32 num_probes) const {
341     // Quadratic probing.
342     return (i + num_probes) & mask_;
343   }
344 };

}  // namespace internal
}  // namespace gtl
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_LIB_GTL_FLATREP_H_