1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ 17 #define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ 18 19 #include <string.h> 20 #include <utility> 21 #include "tensorflow/core/platform/prefetch.h" 22 #include "tensorflow/core/platform/types.h" 23 24 namespace tensorflow { 25 namespace gtl { 26 namespace internal { 27 28 // Internal representation for FlatMap and FlatSet. 29 // 30 // The representation is an open-addressed hash table. Conceptually, 31 // the representation is a flat array of entries. However we 32 // structure it as an array of buckets where each bucket holds 33 // kWidth entries along with metadata for the kWidth entries. The 34 // metadata marker is 35 // 36 // (a) kEmpty: the entry is empty 37 // (b) kDeleted: the entry has been deleted 38 // (c) other: the entry is occupied and has low-8 bits of its hash. 39 // These hash bits can be used to avoid potentially expensive 40 // key comparisons. 41 // 42 // FlatMap passes in a bucket that contains keys and values, FlatSet 43 // passes in a bucket that does not contain values. 44 template <typename Key, typename Bucket, class Hash, class Eq> 45 class FlatRep { 46 public: 47 // kWidth is the number of entries stored in a bucket. 48 static const uint32 kBase = 3; 49 static const uint32 kWidth = (1 << kBase); 50 FlatRep(size_t N,const Hash & hf,const Eq & eq)51 FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) { 52 Init(N); 53 } FlatRep(const FlatRep & src)54 FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) { 55 Init(src.size()); 56 CopyEntries(src.array_, src.end_, CopyEntry()); 57 } 58 FlatRep(FlatRep && src)59 FlatRep(FlatRep&& src) 60 // Copy rather than move src.hash_ and src.equal_. This is necessary to 61 // leave src in a valid state -- otherwise e.g. if hash_ is an 62 // std::function, moving it would null it out. 63 : hash_(src.hash_), equal_(src.equal_) { 64 // TODO(jlebar): Init(1) still allocates some memory, so this isn't as cheap 65 // as it could be. The fundamental problem is that we need to leave src in 66 // a valid state, and FlatRep *always* owns a nonzero amount of memory. 67 Init(1); 68 swap(src); 69 } 70 ~FlatRep()71 ~FlatRep() { 72 clear_no_resize(); 73 delete[] array_; 74 } 75 76 // Simple accessors. size()77 size_t size() const { return not_empty_ - deleted_; } bucket_count()78 size_t bucket_count() const { return mask_ + 1; } start()79 Bucket* start() const { return array_; } limit()80 Bucket* limit() const { return end_; } hash_function()81 const Hash& hash_function() const { return hash_; } key_eq()82 const Eq& key_eq() const { return equal_; } 83 84 // Overwrite contents of *this with contents of src. CopyFrom(const FlatRep & src)85 void CopyFrom(const FlatRep& src) { 86 if (this != &src) { 87 clear_no_resize(); 88 delete[] array_; 89 Init(src.size()); 90 CopyEntries(src.array_, src.end_, CopyEntry()); 91 } 92 } 93 MoveFrom(FlatRep && src)94 void MoveFrom(FlatRep&& src) { 95 if (this != &src) { 96 swap(src); 97 } 98 } 99 clear_no_resize()100 void clear_no_resize() { 101 for (Bucket* b = array_; b != end_; b++) { 102 for (uint32 i = 0; i < kWidth; i++) { 103 if (b->marker[i] >= 2) { 104 b->Destroy(i); 105 b->marker[i] = kEmpty; 106 } 107 } 108 } 109 not_empty_ = 0; 110 deleted_ = 0; 111 } 112 clear()113 void clear() { 114 clear_no_resize(); 115 grow_ = 0; // Consider shrinking in MaybeResize() 116 MaybeResize(); 117 } 118 swap(FlatRep & x)119 void swap(FlatRep& x) { 120 using std::swap; 121 swap(array_, x.array_); 122 swap(end_, x.end_); 123 swap(lglen_, x.lglen_); 124 swap(mask_, x.mask_); 125 swap(not_empty_, x.not_empty_); 126 swap(deleted_, x.deleted_); 127 swap(grow_, x.grow_); 128 swap(shrink_, x.shrink_); 129 } 130 131 struct SearchResult { 132 bool found; 133 Bucket* b; 134 uint32 index; 135 }; 136 137 // Hash value is partitioned as follows: 138 // 1. Bottom 8 bits are stored in bucket to help speed up comparisons. 139 // 2. Next 3 bits give index inside bucket. 140 // 3. Remaining bits give bucket number. 141 142 // Find bucket/index for key k. Find(const Key & k)143 SearchResult Find(const Key& k) const { 144 size_t h = hash_(k); 145 const uint32 marker = Marker(h & 0xff); 146 size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket 147 uint32 num_probes = 1; // Needed for quadratic probing 148 while (true) { 149 uint32 bi = index & (kWidth - 1); 150 Bucket* b = &array_[index >> kBase]; 151 const uint32 x = b->marker[bi]; 152 if (x == marker && equal_(b->key(bi), k)) { 153 return {true, b, bi}; 154 } else if (x == kEmpty) { 155 return {false, nullptr, 0}; 156 } 157 index = NextIndex(index, num_probes); 158 num_probes++; 159 } 160 } 161 162 // Find bucket/index for key k, creating a new one if necessary. 163 // 164 // KeyType is a template parameter so that k's type is deduced and it 165 // becomes a universal reference which allows the key initialization 166 // below to use an rvalue constructor if available. 167 template <typename KeyType> FindOrInsert(KeyType && k)168 SearchResult FindOrInsert(KeyType&& k) { 169 size_t h = hash_(k); 170 const uint32 marker = Marker(h & 0xff); 171 size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket 172 uint32 num_probes = 1; // Needed for quadratic probing 173 Bucket* del = nullptr; // First encountered deletion for kInsert 174 uint32 di = 0; 175 while (true) { 176 uint32 bi = index & (kWidth - 1); 177 Bucket* b = &array_[index >> kBase]; 178 const uint32 x = b->marker[bi]; 179 if (x == marker && equal_(b->key(bi), k)) { 180 return {true, b, bi}; 181 } else if (!del && x == kDeleted) { 182 // Remember deleted index to use for insertion. 183 del = b; 184 di = bi; 185 } else if (x == kEmpty) { 186 if (del) { 187 // Store in the first deleted slot we encountered 188 b = del; 189 bi = di; 190 deleted_--; // not_empty_ does not change 191 } else { 192 not_empty_++; 193 } 194 b->marker[bi] = marker; 195 new (&b->key(bi)) Key(std::forward<KeyType>(k)); 196 return {false, b, bi}; 197 } 198 index = NextIndex(index, num_probes); 199 num_probes++; 200 } 201 } 202 Erase(Bucket * b,uint32 i)203 void Erase(Bucket* b, uint32 i) { 204 b->Destroy(i); 205 b->marker[i] = kDeleted; 206 deleted_++; 207 grow_ = 0; // Consider shrinking on next insert 208 } 209 Prefetch(const Key & k)210 void Prefetch(const Key& k) const { 211 size_t h = hash_(k); 212 size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket 213 uint32 bi = index & (kWidth - 1); 214 Bucket* b = &array_[index >> kBase]; 215 port::prefetch<port::PREFETCH_HINT_T0>(&b->marker[bi]); 216 port::prefetch<port::PREFETCH_HINT_T0>(&b->storage.key[bi]); 217 } 218 MaybeResize()219 inline void MaybeResize() { 220 if (not_empty_ < grow_) { 221 return; // Nothing to do 222 } 223 if (grow_ == 0) { 224 // Special value set by erase to cause shrink on next insert. 225 if (size() >= shrink_) { 226 // Not small enough to shrink. 227 grow_ = static_cast<size_t>(bucket_count() * 0.8); 228 if (not_empty_ < grow_) return; 229 } 230 } 231 Resize(size() + 1); 232 } 233 Resize(size_t N)234 void Resize(size_t N) { 235 Bucket* old = array_; 236 Bucket* old_end = end_; 237 Init(N); 238 CopyEntries(old, old_end, MoveEntry()); 239 delete[] old; 240 } 241 242 private: 243 enum { kEmpty = 0, kDeleted = 1 }; // Special markers for an entry. 244 245 Hash hash_; // User-supplied hasher 246 Eq equal_; // User-supplied comparator 247 uint8 lglen_; // lg(#buckets) 248 Bucket* array_; // array of length (1 << lglen_) 249 Bucket* end_; // Points just past last bucket in array_ 250 size_t mask_; // (# of entries in table) - 1 251 size_t not_empty_; // Count of entries with marker != kEmpty 252 size_t deleted_; // Count of entries with marker == kDeleted 253 size_t grow_; // Grow array when not_empty_ >= grow_ 254 size_t shrink_; // Shrink array when size() < shrink_ 255 256 // Avoid kEmpty and kDeleted markers when computing hash values to 257 // store in Bucket::marker[]. Marker(uint32 hb)258 static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); } 259 Init(size_t N)260 void Init(size_t N) { 261 // Make enough room for N elements. 262 size_t lg = 0; // Smallest table is just one bucket. 263 while (N >= 0.8 * ((1 << lg) * kWidth)) { 264 lg++; 265 } 266 const size_t n = (1 << lg); 267 Bucket* array = new Bucket[n]; 268 for (size_t i = 0; i < n; i++) { 269 Bucket* b = &array[i]; 270 memset(b->marker, kEmpty, kWidth); 271 } 272 const size_t capacity = (1 << lg) * kWidth; 273 lglen_ = lg; 274 mask_ = capacity - 1; 275 array_ = array; 276 end_ = array + n; 277 not_empty_ = 0; 278 deleted_ = 0; 279 grow_ = static_cast<size_t>(capacity * 0.8); 280 if (lg == 0) { 281 // Already down to one bucket; no more shrinking. 282 shrink_ = 0; 283 } else { 284 shrink_ = static_cast<size_t>(grow_ * 0.4); // Must be less than 0.5 285 } 286 } 287 288 // Used by FreshInsert when we should copy from source. 289 struct CopyEntry { operatorCopyEntry290 inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) { 291 dst->CopyFrom(dsti, src, srci); 292 } 293 }; 294 295 // Used by FreshInsert when we should move from source. 296 struct MoveEntry { operatorMoveEntry297 inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) { 298 dst->MoveFrom(dsti, src, srci); 299 src->Destroy(srci); 300 src->marker[srci] = kDeleted; 301 } 302 }; 303 304 template <typename Copier> CopyEntries(Bucket * start,Bucket * end,Copier copier)305 void CopyEntries(Bucket* start, Bucket* end, Copier copier) { 306 for (Bucket* b = start; b != end; b++) { 307 for (uint32 i = 0; i < kWidth; i++) { 308 if (b->marker[i] >= 2) { 309 FreshInsert(b, i, copier); 310 } 311 } 312 } 313 } 314 315 // Create an entry for the key numbered src_index in *src and return 316 // its bucket/index. Used for insertion into a fresh table. We 317 // assume that there are no deletions, and k does not already exist 318 // in the table. 319 template <typename Copier> FreshInsert(Bucket * src,uint32 src_index,Copier copier)320 void FreshInsert(Bucket* src, uint32 src_index, Copier copier) { 321 size_t h = hash_(src->key(src_index)); 322 const uint32 marker = Marker(h & 0xff); 323 size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket 324 uint32 num_probes = 1; // Needed for quadratic probing 325 while (true) { 326 uint32 bi = index & (kWidth - 1); 327 Bucket* b = &array_[index >> kBase]; 328 const uint32 x = b->marker[bi]; 329 if (x == 0) { 330 b->marker[bi] = marker; 331 not_empty_++; 332 copier(b, bi, src, src_index); 333 return; 334 } 335 index = NextIndex(index, num_probes); 336 num_probes++; 337 } 338 } 339 NextIndex(size_t i,uint32 num_probes)340 inline size_t NextIndex(size_t i, uint32 num_probes) const { 341 // Quadratic probing. 342 return (i + num_probes) & mask_; 343 } 344 }; 345 346 } // namespace internal 347 } // namespace gtl 348 } // namespace tensorflow 349 350 #endif // TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ 351