1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ 17 #define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ 18 19 #include <stddef.h> 20 #include <functional> 21 #include <initializer_list> 22 #include <iterator> 23 #include <utility> 24 #include "tensorflow/core/lib/gtl/flatrep.h" 25 #include "tensorflow/core/lib/hash/hash.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 namespace gtl { 31 32 // FlatMap<K,V,...> provides a map from K to V. 33 // 34 // The map is implemented using an open-addressed hash table. A 35 // single array holds entire map contents and collisions are resolved 36 // by probing at a sequence of locations in the array. 37 template <typename Key, typename Val, class Hash = hash<Key>, 38 class Eq = std::equal_to<Key>> 39 class FlatMap { 40 private: 41 // Forward declare some internal types needed in public section. 42 struct Bucket; 43 44 // We cannot use std::pair<> since internal representation stores 45 // keys and values in separate arrays, so we make a custom struct 46 // that holds references to the internal key, value elements. 47 // 48 // We define the struct as private ValueType, and typedef it as public 49 // value_type, to work around a gcc bug when compiling the iterators. 50 struct ValueType { 51 typedef Key first_type; 52 typedef Val second_type; 53 54 const Key& first; 55 Val& second; ValueTypeValueType56 ValueType(const Key& k, Val& v) : first(k), second(v) {} 57 }; 58 59 public: 60 typedef Key key_type; 61 typedef Val mapped_type; 62 typedef Hash hasher; 63 typedef Eq key_equal; 64 typedef size_t size_type; 65 typedef ptrdiff_t difference_type; 66 typedef ValueType value_type; 67 typedef value_type* pointer; 68 typedef const value_type* const_pointer; 69 typedef value_type& reference; 70 typedef const value_type& const_reference; 71 FlatMap()72 FlatMap() : FlatMap(1) {} 73 74 explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) rep_(N,hf,eq)75 : rep_(N, hf, eq) {} 76 FlatMap(const FlatMap & src)77 FlatMap(const FlatMap& src) : rep_(src.rep_) {} 78 79 // Move constructor leaves src in a valid but unspecified state (same as 80 // std::unordered_map). FlatMap(FlatMap && src)81 FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {} 82 83 template <typename InputIter> 84 FlatMap(InputIter first, InputIter last, size_t N = 1, 85 const Hash& hf = Hash(), const Eq& eq = Eq()) FlatMap(N,hf,eq)86 : FlatMap(N, hf, eq) { 87 insert(first, last); 88 } 89 90 FlatMap(std::initializer_list<std::pair<const Key, Val>> init, size_t N = 1, 91 const Hash& hf = Hash(), const Eq& eq = Eq()) 92 : FlatMap(init.begin(), init.end(), N, hf, eq) {} 93 94 FlatMap& operator=(const FlatMap& src) { 95 rep_.CopyFrom(src.rep_); 96 return *this; 97 } 98 99 // Move-assignment operator leaves src in a valid but unspecified state (same 100 // as std::unordered_map). 101 FlatMap& operator=(FlatMap&& src) { 102 rep_.MoveFrom(std::move(src.rep_)); 103 return *this; 104 } 105 ~FlatMap()106 ~FlatMap() {} 107 swap(FlatMap & x)108 void swap(FlatMap& x) { rep_.swap(x.rep_); } clear_no_resize()109 void clear_no_resize() { rep_.clear_no_resize(); } clear()110 void clear() { rep_.clear(); } reserve(size_t N)111 void reserve(size_t N) { rep_.Resize(std::max(N, size())); } rehash(size_t N)112 void rehash(size_t N) { rep_.Resize(std::max(N, size())); } resize(size_t N)113 void resize(size_t N) { rep_.Resize(std::max(N, size())); } size()114 size_t size() const { return rep_.size(); } empty()115 bool empty() const { return size() == 0; } bucket_count()116 size_t bucket_count() const { return rep_.bucket_count(); } hash_function()117 hasher hash_function() const { return rep_.hash_function(); } key_eq()118 key_equal key_eq() const { return rep_.key_eq(); } 119 120 class iterator { 121 public: 122 typedef typename FlatMap::difference_type difference_type; 123 typedef typename FlatMap::value_type value_type; 124 typedef typename FlatMap::pointer pointer; 125 typedef typename FlatMap::reference reference; 126 typedef ::std::forward_iterator_tag iterator_category; 127 iterator()128 iterator() : b_(nullptr), end_(nullptr), i_(0) {} 129 130 // Make iterator pointing at first element at or after b. iterator(Bucket * b,Bucket * end)131 iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { SkipUnused(); } 132 133 // Make iterator pointing exactly at ith element in b, which must exist. iterator(Bucket * b,Bucket * end,uint32 i)134 iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) { 135 FillValue(); 136 } 137 138 reference operator*() { return *val(); } 139 pointer operator->() { return val(); } 140 bool operator==(const iterator& x) const { 141 return b_ == x.b_ && i_ == x.i_; 142 } 143 bool operator!=(const iterator& x) const { return !(*this == x); } 144 iterator& operator++() { 145 DCHECK(b_ != end_); 146 i_++; 147 SkipUnused(); 148 return *this; 149 } 150 iterator operator++(int /*indicates postfix*/) { 151 iterator tmp(*this); 152 ++*this; 153 return tmp; 154 } 155 156 private: 157 friend class FlatMap; 158 Bucket* b_; 159 Bucket* end_; 160 char space_ alignas(value_type)[sizeof(value_type)]; 161 uint32 i_; 162 val()163 pointer val() { return reinterpret_cast<pointer>(space_); } FillValue()164 void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); } SkipUnused()165 void SkipUnused() { 166 while (b_ < end_) { 167 if (i_ >= Rep::kWidth) { 168 i_ = 0; 169 b_++; 170 } else if (b_->marker[i_] < 2) { 171 i_++; 172 } else { 173 FillValue(); 174 break; 175 } 176 } 177 } 178 }; 179 180 class const_iterator { 181 private: 182 mutable iterator rep_; // Share state and logic with non-const iterator. 183 public: 184 typedef typename FlatMap::difference_type difference_type; 185 typedef typename FlatMap::value_type value_type; 186 typedef typename FlatMap::const_pointer pointer; 187 typedef typename FlatMap::const_reference reference; 188 typedef ::std::forward_iterator_tag iterator_category; 189 const_iterator()190 const_iterator() : rep_() {} const_iterator(Bucket * start,Bucket * end)191 const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {} const_iterator(Bucket * b,Bucket * end,uint32 i)192 const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {} 193 194 reference operator*() const { return *rep_.val(); } 195 pointer operator->() const { return rep_.val(); } 196 bool operator==(const const_iterator& x) const { return rep_ == x.rep_; } 197 bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; } 198 const_iterator& operator++() { 199 ++rep_; 200 return *this; 201 } 202 const_iterator operator++(int /*indicates postfix*/) { 203 const_iterator tmp(*this); 204 ++*this; 205 return tmp; 206 } 207 }; 208 begin()209 iterator begin() { return iterator(rep_.start(), rep_.limit()); } end()210 iterator end() { return iterator(rep_.limit(), rep_.limit()); } begin()211 const_iterator begin() const { 212 return const_iterator(rep_.start(), rep_.limit()); 213 } end()214 const_iterator end() const { 215 return const_iterator(rep_.limit(), rep_.limit()); 216 } 217 count(const Key & k)218 size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } find(const Key & k)219 iterator find(const Key& k) { 220 auto r = rep_.Find(k); 221 return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); 222 } find(const Key & k)223 const_iterator find(const Key& k) const { 224 auto r = rep_.Find(k); 225 return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); 226 } 227 at(const Key & k)228 Val& at(const Key& k) { 229 auto r = rep_.Find(k); 230 DCHECK(r.found); 231 return r.b->val(r.index); 232 } at(const Key & k)233 const Val& at(const Key& k) const { 234 auto r = rep_.Find(k); 235 DCHECK(r.found); 236 return r.b->val(r.index); 237 } 238 239 template <typename P> insert(const P & p)240 std::pair<iterator, bool> insert(const P& p) { 241 return Insert(p.first, p.second); 242 } insert(const std::pair<const Key,Val> & p)243 std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) { 244 return Insert(p.first, p.second); 245 } 246 template <typename InputIter> insert(InputIter first,InputIter last)247 void insert(InputIter first, InputIter last) { 248 for (; first != last; ++first) { 249 insert(*first); 250 } 251 } 252 253 Val& operator[](const Key& k) { return IndexOp(k); } 254 Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); } 255 256 template <typename... Args> emplace(Args &&...args)257 std::pair<iterator, bool> emplace(Args&&... args) { 258 return InsertPair(std::make_pair(std::forward<Args>(args)...)); 259 } 260 erase(const Key & k)261 size_t erase(const Key& k) { 262 auto r = rep_.Find(k); 263 if (!r.found) return 0; 264 rep_.Erase(r.b, r.index); 265 return 1; 266 } erase(iterator pos)267 iterator erase(iterator pos) { 268 rep_.Erase(pos.b_, pos.i_); 269 ++pos; 270 return pos; 271 } erase(iterator pos,iterator last)272 iterator erase(iterator pos, iterator last) { 273 for (; pos != last; ++pos) { 274 rep_.Erase(pos.b_, pos.i_); 275 } 276 return pos; 277 } 278 equal_range(const Key & k)279 std::pair<iterator, iterator> equal_range(const Key& k) { 280 auto pos = find(k); 281 if (pos == end()) { 282 return std::make_pair(pos, pos); 283 } else { 284 auto next = pos; 285 ++next; 286 return std::make_pair(pos, next); 287 } 288 } equal_range(const Key & k)289 std::pair<const_iterator, const_iterator> equal_range(const Key& k) const { 290 auto pos = find(k); 291 if (pos == end()) { 292 return std::make_pair(pos, pos); 293 } else { 294 auto next = pos; 295 ++next; 296 return std::make_pair(pos, next); 297 } 298 } 299 300 bool operator==(const FlatMap& x) const { 301 if (size() != x.size()) return false; 302 for (auto& p : x) { 303 auto i = find(p.first); 304 if (i == end()) return false; 305 if (i->second != p.second) return false; 306 } 307 return true; 308 } 309 bool operator!=(const FlatMap& x) const { return !(*this == x); } 310 311 // If key exists in the table, prefetch the associated value. This 312 // is a hint, and may have no effect. prefetch_value(const Key & key)313 void prefetch_value(const Key& key) const { rep_.Prefetch(key); } 314 315 private: 316 using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>; 317 318 // Bucket stores kWidth <marker, key, value> triples. 319 // The data is organized as three parallel arrays to reduce padding. 320 struct Bucket { 321 uint8 marker[Rep::kWidth]; 322 323 // Wrap keys and values in union to control construction and destruction. 324 union Storage { 325 struct { 326 Key key[Rep::kWidth]; 327 Val val[Rep::kWidth]; 328 }; Storage()329 Storage() {} ~Storage()330 ~Storage() {} 331 } storage; 332 keyBucket333 Key& key(uint32 i) { 334 DCHECK_GE(marker[i], 2); 335 return storage.key[i]; 336 } valBucket337 Val& val(uint32 i) { 338 DCHECK_GE(marker[i], 2); 339 return storage.val[i]; 340 } 341 template <typename V> InitValBucket342 void InitVal(uint32 i, V&& v) { 343 new (&storage.val[i]) Val(std::forward<V>(v)); 344 } DestroyBucket345 void Destroy(uint32 i) { 346 storage.key[i].Key::~Key(); 347 storage.val[i].Val::~Val(); 348 } MoveFromBucket349 void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { 350 new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); 351 new (&storage.val[i]) Val(std::move(src->storage.val[src_index])); 352 } CopyFromBucket353 void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { 354 new (&storage.key[i]) Key(src->storage.key[src_index]); 355 new (&storage.val[i]) Val(src->storage.val[src_index]); 356 } 357 }; 358 359 template <typename Pair> InsertPair(Pair && p)360 std::pair<iterator, bool> InsertPair(Pair&& p) { 361 return Insert(std::forward<decltype(p.first)>(p.first), 362 std::forward<decltype(p.second)>(p.second)); 363 } 364 365 template <typename K, typename V> Insert(K && k,V && v)366 std::pair<iterator, bool> Insert(K&& k, V&& v) { 367 rep_.MaybeResize(); 368 auto r = rep_.FindOrInsert(std::forward<K>(k)); 369 const bool inserted = !r.found; 370 if (inserted) { 371 r.b->InitVal(r.index, std::forward<V>(v)); 372 } 373 return {iterator(r.b, rep_.limit(), r.index), inserted}; 374 } 375 376 template <typename K> IndexOp(K && k)377 Val& IndexOp(K&& k) { 378 rep_.MaybeResize(); 379 auto r = rep_.FindOrInsert(std::forward<K>(k)); 380 Val* vptr = &r.b->val(r.index); 381 if (!r.found) { 382 new (vptr) Val(); // Initialize value in new slot. 383 } 384 return *vptr; 385 } 386 387 Rep rep_; 388 }; 389 390 } // namespace gtl 391 } // namespace tensorflow 392 393 #endif // TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ 394