1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
17 #define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
18 
#include <stddef.h>
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <new>
#include <utility>
#include "tensorflow/core/lib/gtl/flatrep.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace gtl {
31 
// FlatSet<K,...> provides a set of K.
//
// The set is implemented using an open-addressed hash table.  A
// single array holds the entire set contents and collisions are resolved
// by probing at a sequence of locations in the array.
37 template <typename Key, class Hash = hash<Key>, class Eq = std::equal_to<Key>>
38 class FlatSet {
39  private:
40   // Forward declare some internal types needed in public section.
41   struct Bucket;
42 
43  public:
44   typedef Key key_type;
45   typedef Key value_type;
46   typedef Hash hasher;
47   typedef Eq key_equal;
48   typedef size_t size_type;
49   typedef ptrdiff_t difference_type;
50   typedef value_type* pointer;
51   typedef const value_type* const_pointer;
52   typedef value_type& reference;
53   typedef const value_type& const_reference;
54 
FlatSet()55   FlatSet() : FlatSet(1) {}
56 
57   explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
rep_(N,hf,eq)58       : rep_(N, hf, eq) {}
59 
FlatSet(const FlatSet & src)60   FlatSet(const FlatSet& src) : rep_(src.rep_) {}
61 
62   // Move constructor leaves src in a valid but unspecified state (same as
63   // std::unordered_set).
FlatSet(FlatSet && src)64   FlatSet(FlatSet&& src) : rep_(std::move(src.rep_)) {}
65 
66   template <typename InputIter>
67   FlatSet(InputIter first, InputIter last, size_t N = 1,
68           const Hash& hf = Hash(), const Eq& eq = Eq())
FlatSet(N,hf,eq)69       : FlatSet(N, hf, eq) {
70     insert(first, last);
71   }
72 
73   FlatSet(std::initializer_list<value_type> init, size_t N = 1,
74           const Hash& hf = Hash(), const Eq& eq = Eq())
75       : FlatSet(init.begin(), init.end(), N, hf, eq) {}
76 
77   FlatSet& operator=(const FlatSet& src) {
78     rep_.CopyFrom(src.rep_);
79     return *this;
80   }
81 
82   // Move-assignment operator leaves src in a valid but unspecified state (same
83   // as std::unordered_set).
84   FlatSet& operator=(FlatSet&& src) {
85     rep_.MoveFrom(std::move(src.rep_));
86     return *this;
87   }
88 
~FlatSet()89   ~FlatSet() {}
90 
swap(FlatSet & x)91   void swap(FlatSet& x) { rep_.swap(x.rep_); }
clear_no_resize()92   void clear_no_resize() { rep_.clear_no_resize(); }
clear()93   void clear() { rep_.clear(); }
reserve(size_t N)94   void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
rehash(size_t N)95   void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
resize(size_t N)96   void resize(size_t N) { rep_.Resize(std::max(N, size())); }
size()97   size_t size() const { return rep_.size(); }
empty()98   bool empty() const { return size() == 0; }
bucket_count()99   size_t bucket_count() const { return rep_.bucket_count(); }
hash_function()100   hasher hash_function() const { return rep_.hash_function(); }
key_eq()101   key_equal key_eq() const { return rep_.key_eq(); }
102 
103   class const_iterator {
104    public:
105     typedef typename FlatSet::difference_type difference_type;
106     typedef typename FlatSet::value_type value_type;
107     typedef typename FlatSet::const_pointer pointer;
108     typedef typename FlatSet::const_reference reference;
109     typedef ::std::forward_iterator_tag iterator_category;
110 
const_iterator()111     const_iterator() : b_(nullptr), end_(nullptr), i_(0) {}
112 
113     // Make iterator pointing at first element at or after b.
const_iterator(Bucket * b,Bucket * end)114     const_iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) {
115       SkipUnused();
116     }
117 
118     // Make iterator pointing exactly at ith element in b, which must exist.
const_iterator(Bucket * b,Bucket * end,uint32 i)119     const_iterator(Bucket* b, Bucket* end, uint32 i)
120         : b_(b), end_(end), i_(i) {}
121 
122     reference operator*() const { return key(); }
123     pointer operator->() const { return &key(); }
124     bool operator==(const const_iterator& x) const {
125       return b_ == x.b_ && i_ == x.i_;
126     }
127     bool operator!=(const const_iterator& x) const { return !(*this == x); }
128     const_iterator& operator++() {
129       DCHECK(b_ != end_);
130       i_++;
131       SkipUnused();
132       return *this;
133     }
134     const_iterator operator++(int /*indicates postfix*/) {
135       const_iterator tmp(*this);
136       ++*this;
137       return tmp;
138     }
139 
140    private:
141     friend class FlatSet;
142     Bucket* b_;
143     Bucket* end_;
144     uint32 i_;
145 
key()146     reference key() const { return b_->key(i_); }
SkipUnused()147     void SkipUnused() {
148       while (b_ < end_) {
149         if (i_ >= Rep::kWidth) {
150           i_ = 0;
151           b_++;
152         } else if (b_->marker[i_] < 2) {
153           i_++;
154         } else {
155           break;
156         }
157       }
158     }
159   };
160 
161   typedef const_iterator iterator;
162 
begin()163   iterator begin() { return iterator(rep_.start(), rep_.limit()); }
end()164   iterator end() { return iterator(rep_.limit(), rep_.limit()); }
begin()165   const_iterator begin() const {
166     return const_iterator(rep_.start(), rep_.limit());
167   }
end()168   const_iterator end() const {
169     return const_iterator(rep_.limit(), rep_.limit());
170   }
171 
count(const Key & k)172   size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
find(const Key & k)173   iterator find(const Key& k) {
174     auto r = rep_.Find(k);
175     return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
176   }
find(const Key & k)177   const_iterator find(const Key& k) const {
178     auto r = rep_.Find(k);
179     return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
180   }
181 
insert(const Key & k)182   std::pair<iterator, bool> insert(const Key& k) { return Insert(k); }
insert(Key && k)183   std::pair<iterator, bool> insert(Key&& k) { return Insert(std::move(k)); }
184   template <typename InputIter>
insert(InputIter first,InputIter last)185   void insert(InputIter first, InputIter last) {
186     for (; first != last; ++first) {
187       insert(*first);
188     }
189   }
190 
191   template <typename... Args>
emplace(Args &&...args)192   std::pair<iterator, bool> emplace(Args&&... args) {
193     rep_.MaybeResize();
194     auto r = rep_.FindOrInsert(std::forward<Args>(args)...);
195     const bool inserted = !r.found;
196     return {iterator(r.b, rep_.limit(), r.index), inserted};
197   }
198 
erase(const Key & k)199   size_t erase(const Key& k) {
200     auto r = rep_.Find(k);
201     if (!r.found) return 0;
202     rep_.Erase(r.b, r.index);
203     return 1;
204   }
erase(iterator pos)205   iterator erase(iterator pos) {
206     rep_.Erase(pos.b_, pos.i_);
207     ++pos;
208     return pos;
209   }
erase(iterator pos,iterator last)210   iterator erase(iterator pos, iterator last) {
211     for (; pos != last; ++pos) {
212       rep_.Erase(pos.b_, pos.i_);
213     }
214     return pos;
215   }
216 
equal_range(const Key & k)217   std::pair<iterator, iterator> equal_range(const Key& k) {
218     auto pos = find(k);
219     if (pos == end()) {
220       return std::make_pair(pos, pos);
221     } else {
222       auto next = pos;
223       ++next;
224       return std::make_pair(pos, next);
225     }
226   }
equal_range(const Key & k)227   std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
228     auto pos = find(k);
229     if (pos == end()) {
230       return std::make_pair(pos, pos);
231     } else {
232       auto next = pos;
233       ++next;
234       return std::make_pair(pos, next);
235     }
236   }
237 
238   bool operator==(const FlatSet& x) const {
239     if (size() != x.size()) return false;
240     for (const auto& elem : x) {
241       auto i = find(elem);
242       if (i == end()) return false;
243     }
244     return true;
245   }
246   bool operator!=(const FlatSet& x) const { return !(*this == x); }
247 
248   // If key exists in the table, prefetch it.  This is a hint, and may
249   // have no effect.
prefetch_value(const Key & key)250   void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
251 
252  private:
253   using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
254 
255   // Bucket stores kWidth <marker, key, value> triples.
256   // The data is organized as three parallel arrays to reduce padding.
257   struct Bucket {
258     uint8 marker[Rep::kWidth];
259 
260     // Wrap keys in union to control construction and destruction.
261     union Storage {
262       Key key[Rep::kWidth];
Storage()263       Storage() {}
~Storage()264       ~Storage() {}
265     } storage;
266 
keyBucket267     Key& key(uint32 i) {
268       DCHECK_GE(marker[i], 2);
269       return storage.key[i];
270     }
DestroyBucket271     void Destroy(uint32 i) { storage.key[i].Key::~Key(); }
MoveFromBucket272     void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
273       new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
274     }
CopyFromBucket275     void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
276       new (&storage.key[i]) Key(src->storage.key[src_index]);
277     }
278   };
279 
280   template <typename K>
Insert(K && k)281   std::pair<iterator, bool> Insert(K&& k) {
282     rep_.MaybeResize();
283     auto r = rep_.FindOrInsert(std::forward<K>(k));
284     const bool inserted = !r.found;
285     return {iterator(r.b, rep_.limit(), r.index), inserted};
286   }
287 
288   Rep rep_;
289 };
290 
291 }  // namespace gtl
292 }  // namespace tensorflow
293 
294 #endif  // TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
295