1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
17 #define TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
18 
19 #include <type_traits>
20 #include "tensorflow/core/lib/gtl/flatset.h"
21 
22 namespace tensorflow {
23 namespace gtl {
24 
25 // CompactPointerSet<T> is like a std::unordered_set<T> but is optimized
26 // for small sets (<= 1 element).  T must be a pointer type.
27 template <typename T>
28 class CompactPointerSet {
29  private:
30   using BigRep = FlatSet<T>;
31 
32  public:
33   using value_type = T;
34 
CompactPointerSet()35   CompactPointerSet() : rep_(0) {}
36 
~CompactPointerSet()37   ~CompactPointerSet() {
38     static_assert(
39         std::is_pointer<T>::value,
40         "CompactPointerSet<T> can only be used with T's that are pointers");
41     if (isbig()) delete big();
42   }
43 
CompactPointerSet(const CompactPointerSet & other)44   CompactPointerSet(const CompactPointerSet& other) : rep_(0) { *this = other; }
45 
46   CompactPointerSet& operator=(const CompactPointerSet& other) {
47     if (this == &other) return *this;
48     if (other.isbig()) {
49       // big => any
50       if (!isbig()) MakeBig();
51       *big() = *other.big();
52     } else if (isbig()) {
53       // !big => big
54       big()->clear();
55       if (other.rep_ != 0) {
56         big()->insert(reinterpret_cast<T>(other.rep_));
57       }
58     } else {
59       // !big => !big
60       rep_ = other.rep_;
61     }
62     return *this;
63   }
64 
65   class iterator {
66    public:
67     typedef ssize_t difference_type;
68     typedef T value_type;
69     typedef const T* pointer;
70     typedef const T& reference;
71     typedef ::std::forward_iterator_tag iterator_category;
72 
iterator(uintptr_t rep)73     explicit iterator(uintptr_t rep)
74         : bigrep_(false), single_(reinterpret_cast<T>(rep)) {}
iterator(typename BigRep::iterator iter)75     explicit iterator(typename BigRep::iterator iter)
76         : bigrep_(true), single_(nullptr), iter_(iter) {}
77 
78     iterator& operator++() {
79       if (bigrep_) {
80         ++iter_;
81       } else {
82         DCHECK(single_ != nullptr);
83         single_ = nullptr;
84       }
85       return *this;
86     }
87     // maybe post-increment?
88 
89     bool operator==(const iterator& other) const {
90       if (bigrep_) {
91         return iter_ == other.iter_;
92       } else {
93         return single_ == other.single_;
94       }
95     }
96     bool operator!=(const iterator& other) const { return !(*this == other); }
97 
98     const T& operator*() const {
99       if (bigrep_) {
100         return *iter_;
101       } else {
102         DCHECK(single_ != nullptr);
103         return single_;
104       }
105     }
106 
107    private:
108     friend class CompactPointerSet;
109     bool bigrep_;
110     T single_;
111     typename BigRep::iterator iter_;
112   };
113   using const_iterator = iterator;
114 
empty()115   bool empty() const { return isbig() ? big()->empty() : (rep_ == 0); }
size()116   size_t size() const { return isbig() ? big()->size() : (rep_ == 0 ? 0 : 1); }
117 
clear()118   void clear() {
119     if (isbig()) {
120       delete big();
121     }
122     rep_ = 0;
123   }
124 
insert(T elem)125   std::pair<iterator, bool> insert(T elem) {
126     if (!isbig()) {
127       if (rep_ == 0) {
128         uintptr_t v = reinterpret_cast<uintptr_t>(elem);
129         if (v == 0 || ((v & 0x3) != 0)) {
130           // Cannot use small representation for nullptr.  Fall through.
131         } else {
132           rep_ = v;
133           return {iterator(v), true};
134         }
135       }
136       MakeBig();
137     }
138     auto p = big()->insert(elem);
139     return {iterator(p.first), p.second};
140   }
141 
142   template <typename InputIter>
insert(InputIter begin,InputIter end)143   void insert(InputIter begin, InputIter end) {
144     for (; begin != end; ++begin) {
145       insert(*begin);
146     }
147   }
148 
begin()149   const_iterator begin() const {
150     return isbig() ? iterator(big()->begin()) : iterator(rep_);
151   }
end()152   const_iterator end() const {
153     return isbig() ? iterator(big()->end()) : iterator(0);
154   }
155 
find(T elem)156   iterator find(T elem) const {
157     if (rep_ == reinterpret_cast<uintptr_t>(elem)) {
158       return iterator(rep_);
159     } else if (!isbig()) {
160       return iterator(0);
161     } else {
162       return iterator(big()->find(elem));
163     }
164   }
165 
count(T elem)166   size_t count(T elem) const { return find(elem) != end() ? 1 : 0; }
167 
erase(T elem)168   size_t erase(T elem) {
169     if (!isbig()) {
170       if (rep_ == reinterpret_cast<uintptr_t>(elem)) {
171         rep_ = 0;
172         return 1;
173       } else {
174         return 0;
175       }
176     } else {
177       return big()->erase(elem);
178     }
179   }
180 
181  private:
182   // Size         rep_
183   // -------------------------------------------------------------------------
184   // 0            0
185   // 1            The pointer itself (bottom bits == 00)
186   // large        Pointer to a BigRep (bottom bits == 01)
187   uintptr_t rep_;
188 
isbig()189   bool isbig() const { return (rep_ & 0x3) == 1; }
big()190   BigRep* big() const {
191     DCHECK(isbig());
192     return reinterpret_cast<BigRep*>(rep_ - 1);
193   }
194 
MakeBig()195   void MakeBig() {
196     DCHECK(!isbig());
197     BigRep* big = new BigRep;
198     if (rep_ != 0) {
199       big->insert(reinterpret_cast<T>(rep_));
200     }
201     rep_ = reinterpret_cast<uintptr_t>(big) + 0x1;
202   }
203 };
204 
205 }  // namespace gtl
206 }  // namespace tensorflow
207 
208 #endif  // TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
209