1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
17 #define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
18 
#include <stddef.h>
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <new>
#include <utility>
#include "tensorflow/core/lib/gtl/flatrep.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace gtl {
31 
32 // FlatMap<K,V,...> provides a map from K to V.
33 //
34 // The map is implemented using an open-addressed hash table.  A
35 // single array holds entire map contents and collisions are resolved
36 // by probing at a sequence of locations in the array.
37 template <typename Key, typename Val, class Hash = hash<Key>,
38           class Eq = std::equal_to<Key>>
39 class FlatMap {
40  private:
41   // Forward declare some internal types needed in public section.
42   struct Bucket;
43 
44   // We cannot use std::pair<> since internal representation stores
45   // keys and values in separate arrays, so we make a custom struct
46   // that holds references to the internal key, value elements.
47   //
48   // We define the struct as private ValueType, and typedef it as public
49   // value_type, to work around a gcc bug when compiling the iterators.
50   struct ValueType {
51     typedef Key first_type;
52     typedef Val second_type;
53 
54     const Key& first;
55     Val& second;
ValueTypeValueType56     ValueType(const Key& k, Val& v) : first(k), second(v) {}
57   };
58 
59  public:
60   typedef Key key_type;
61   typedef Val mapped_type;
62   typedef Hash hasher;
63   typedef Eq key_equal;
64   typedef size_t size_type;
65   typedef ptrdiff_t difference_type;
66   typedef ValueType value_type;
67   typedef value_type* pointer;
68   typedef const value_type* const_pointer;
69   typedef value_type& reference;
70   typedef const value_type& const_reference;
71 
FlatMap()72   FlatMap() : FlatMap(1) {}
73 
74   explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
rep_(N,hf,eq)75       : rep_(N, hf, eq) {}
76 
FlatMap(const FlatMap & src)77   FlatMap(const FlatMap& src) : rep_(src.rep_) {}
78 
79   // Move constructor leaves src in a valid but unspecified state (same as
80   // std::unordered_map).
FlatMap(FlatMap && src)81   FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {}
82 
83   template <typename InputIter>
84   FlatMap(InputIter first, InputIter last, size_t N = 1,
85           const Hash& hf = Hash(), const Eq& eq = Eq())
FlatMap(N,hf,eq)86       : FlatMap(N, hf, eq) {
87     insert(first, last);
88   }
89 
90   FlatMap(std::initializer_list<std::pair<const Key, Val>> init, size_t N = 1,
91           const Hash& hf = Hash(), const Eq& eq = Eq())
92       : FlatMap(init.begin(), init.end(), N, hf, eq) {}
93 
94   FlatMap& operator=(const FlatMap& src) {
95     rep_.CopyFrom(src.rep_);
96     return *this;
97   }
98 
99   // Move-assignment operator leaves src in a valid but unspecified state (same
100   // as std::unordered_map).
101   FlatMap& operator=(FlatMap&& src) {
102     rep_.MoveFrom(std::move(src.rep_));
103     return *this;
104   }
105 
~FlatMap()106   ~FlatMap() {}
107 
swap(FlatMap & x)108   void swap(FlatMap& x) { rep_.swap(x.rep_); }
clear_no_resize()109   void clear_no_resize() { rep_.clear_no_resize(); }
clear()110   void clear() { rep_.clear(); }
reserve(size_t N)111   void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
rehash(size_t N)112   void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
resize(size_t N)113   void resize(size_t N) { rep_.Resize(std::max(N, size())); }
size()114   size_t size() const { return rep_.size(); }
empty()115   bool empty() const { return size() == 0; }
bucket_count()116   size_t bucket_count() const { return rep_.bucket_count(); }
hash_function()117   hasher hash_function() const { return rep_.hash_function(); }
key_eq()118   key_equal key_eq() const { return rep_.key_eq(); }
119 
120   class iterator {
121    public:
122     typedef typename FlatMap::difference_type difference_type;
123     typedef typename FlatMap::value_type value_type;
124     typedef typename FlatMap::pointer pointer;
125     typedef typename FlatMap::reference reference;
126     typedef ::std::forward_iterator_tag iterator_category;
127 
iterator()128     iterator() : b_(nullptr), end_(nullptr), i_(0) {}
129 
130     // Make iterator pointing at first element at or after b.
iterator(Bucket * b,Bucket * end)131     iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { SkipUnused(); }
132 
133     // Make iterator pointing exactly at ith element in b, which must exist.
iterator(Bucket * b,Bucket * end,uint32 i)134     iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {
135       FillValue();
136     }
137 
138     reference operator*() { return *val(); }
139     pointer operator->() { return val(); }
140     bool operator==(const iterator& x) const {
141       return b_ == x.b_ && i_ == x.i_;
142     }
143     bool operator!=(const iterator& x) const { return !(*this == x); }
144     iterator& operator++() {
145       DCHECK(b_ != end_);
146       i_++;
147       SkipUnused();
148       return *this;
149     }
150     iterator operator++(int /*indicates postfix*/) {
151       iterator tmp(*this);
152       ++*this;
153       return tmp;
154     }
155 
156    private:
157     friend class FlatMap;
158     Bucket* b_;
159     Bucket* end_;
160     char space_ alignas(value_type)[sizeof(value_type)];
161     uint32 i_;
162 
val()163     pointer val() { return reinterpret_cast<pointer>(space_); }
FillValue()164     void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); }
SkipUnused()165     void SkipUnused() {
166       while (b_ < end_) {
167         if (i_ >= Rep::kWidth) {
168           i_ = 0;
169           b_++;
170         } else if (b_->marker[i_] < 2) {
171           i_++;
172         } else {
173           FillValue();
174           break;
175         }
176       }
177     }
178   };
179 
180   class const_iterator {
181    private:
182     mutable iterator rep_;  // Share state and logic with non-const iterator.
183    public:
184     typedef typename FlatMap::difference_type difference_type;
185     typedef typename FlatMap::value_type value_type;
186     typedef typename FlatMap::const_pointer pointer;
187     typedef typename FlatMap::const_reference reference;
188     typedef ::std::forward_iterator_tag iterator_category;
189 
const_iterator()190     const_iterator() : rep_() {}
const_iterator(Bucket * start,Bucket * end)191     const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
const_iterator(Bucket * b,Bucket * end,uint32 i)192     const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
193 
194     reference operator*() const { return *rep_.val(); }
195     pointer operator->() const { return rep_.val(); }
196     bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
197     bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
198     const_iterator& operator++() {
199       ++rep_;
200       return *this;
201     }
202     const_iterator operator++(int /*indicates postfix*/) {
203       const_iterator tmp(*this);
204       ++*this;
205       return tmp;
206     }
207   };
208 
begin()209   iterator begin() { return iterator(rep_.start(), rep_.limit()); }
end()210   iterator end() { return iterator(rep_.limit(), rep_.limit()); }
begin()211   const_iterator begin() const {
212     return const_iterator(rep_.start(), rep_.limit());
213   }
end()214   const_iterator end() const {
215     return const_iterator(rep_.limit(), rep_.limit());
216   }
217 
count(const Key & k)218   size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
find(const Key & k)219   iterator find(const Key& k) {
220     auto r = rep_.Find(k);
221     return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
222   }
find(const Key & k)223   const_iterator find(const Key& k) const {
224     auto r = rep_.Find(k);
225     return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
226   }
227 
at(const Key & k)228   Val& at(const Key& k) {
229     auto r = rep_.Find(k);
230     DCHECK(r.found);
231     return r.b->val(r.index);
232   }
at(const Key & k)233   const Val& at(const Key& k) const {
234     auto r = rep_.Find(k);
235     DCHECK(r.found);
236     return r.b->val(r.index);
237   }
238 
239   template <typename P>
insert(const P & p)240   std::pair<iterator, bool> insert(const P& p) {
241     return Insert(p.first, p.second);
242   }
insert(const std::pair<const Key,Val> & p)243   std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) {
244     return Insert(p.first, p.second);
245   }
246   template <typename InputIter>
insert(InputIter first,InputIter last)247   void insert(InputIter first, InputIter last) {
248     for (; first != last; ++first) {
249       insert(*first);
250     }
251   }
252 
253   Val& operator[](const Key& k) { return IndexOp(k); }
254   Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); }
255 
256   template <typename... Args>
emplace(Args &&...args)257   std::pair<iterator, bool> emplace(Args&&... args) {
258     return InsertPair(std::make_pair(std::forward<Args>(args)...));
259   }
260 
erase(const Key & k)261   size_t erase(const Key& k) {
262     auto r = rep_.Find(k);
263     if (!r.found) return 0;
264     rep_.Erase(r.b, r.index);
265     return 1;
266   }
erase(iterator pos)267   iterator erase(iterator pos) {
268     rep_.Erase(pos.b_, pos.i_);
269     ++pos;
270     return pos;
271   }
erase(iterator pos,iterator last)272   iterator erase(iterator pos, iterator last) {
273     for (; pos != last; ++pos) {
274       rep_.Erase(pos.b_, pos.i_);
275     }
276     return pos;
277   }
278 
equal_range(const Key & k)279   std::pair<iterator, iterator> equal_range(const Key& k) {
280     auto pos = find(k);
281     if (pos == end()) {
282       return std::make_pair(pos, pos);
283     } else {
284       auto next = pos;
285       ++next;
286       return std::make_pair(pos, next);
287     }
288   }
equal_range(const Key & k)289   std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
290     auto pos = find(k);
291     if (pos == end()) {
292       return std::make_pair(pos, pos);
293     } else {
294       auto next = pos;
295       ++next;
296       return std::make_pair(pos, next);
297     }
298   }
299 
300   bool operator==(const FlatMap& x) const {
301     if (size() != x.size()) return false;
302     for (auto& p : x) {
303       auto i = find(p.first);
304       if (i == end()) return false;
305       if (i->second != p.second) return false;
306     }
307     return true;
308   }
309   bool operator!=(const FlatMap& x) const { return !(*this == x); }
310 
311   // If key exists in the table, prefetch the associated value.  This
312   // is a hint, and may have no effect.
prefetch_value(const Key & key)313   void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
314 
315  private:
316   using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
317 
318   // Bucket stores kWidth <marker, key, value> triples.
319   // The data is organized as three parallel arrays to reduce padding.
320   struct Bucket {
321     uint8 marker[Rep::kWidth];
322 
323     // Wrap keys and values in union to control construction and destruction.
324     union Storage {
325       struct {
326         Key key[Rep::kWidth];
327         Val val[Rep::kWidth];
328       };
Storage()329       Storage() {}
~Storage()330       ~Storage() {}
331     } storage;
332 
keyBucket333     Key& key(uint32 i) {
334       DCHECK_GE(marker[i], 2);
335       return storage.key[i];
336     }
valBucket337     Val& val(uint32 i) {
338       DCHECK_GE(marker[i], 2);
339       return storage.val[i];
340     }
341     template <typename V>
InitValBucket342     void InitVal(uint32 i, V&& v) {
343       new (&storage.val[i]) Val(std::forward<V>(v));
344     }
DestroyBucket345     void Destroy(uint32 i) {
346       storage.key[i].Key::~Key();
347       storage.val[i].Val::~Val();
348     }
MoveFromBucket349     void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
350       new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
351       new (&storage.val[i]) Val(std::move(src->storage.val[src_index]));
352     }
CopyFromBucket353     void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
354       new (&storage.key[i]) Key(src->storage.key[src_index]);
355       new (&storage.val[i]) Val(src->storage.val[src_index]);
356     }
357   };
358 
359   template <typename Pair>
InsertPair(Pair && p)360   std::pair<iterator, bool> InsertPair(Pair&& p) {
361     return Insert(std::forward<decltype(p.first)>(p.first),
362                   std::forward<decltype(p.second)>(p.second));
363   }
364 
365   template <typename K, typename V>
Insert(K && k,V && v)366   std::pair<iterator, bool> Insert(K&& k, V&& v) {
367     rep_.MaybeResize();
368     auto r = rep_.FindOrInsert(std::forward<K>(k));
369     const bool inserted = !r.found;
370     if (inserted) {
371       r.b->InitVal(r.index, std::forward<V>(v));
372     }
373     return {iterator(r.b, rep_.limit(), r.index), inserted};
374   }
375 
376   template <typename K>
IndexOp(K && k)377   Val& IndexOp(K&& k) {
378     rep_.MaybeResize();
379     auto r = rep_.FindOrInsert(std::forward<K>(k));
380     Val* vptr = &r.b->val(r.index);
381     if (!r.found) {
382       new (vptr) Val();  // Initialize value in new slot.
383     }
384     return *vptr;
385   }
386 
387   Rep rep_;
388 };
389 
390 }  // namespace gtl
391 }  // namespace tensorflow
392 
393 #endif  // TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
394