/*
 * Copyright (C) 2014 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ART_RUNTIME_BASE_HASH_SET_H_
#define ART_RUNTIME_BASE_HASH_SET_H_

#include <functional>
#include <memory>
#include <stdint.h>
#include <utility>

#include "bit_utils.h"
#include "logging.h"

namespace art {

// Default empty functions: a value-initialized T() (or nullptr for pointers) is treated as the
// empty value.
template <class T>
class DefaultEmptyFn {
 public:
  void MakeEmpty(T& item) const {
    item = T();
  }
  bool IsEmpty(const T& item) const {
    return item == T();
  }
};

template <class T>
class DefaultEmptyFn<T*> {
 public:
  void MakeEmpty(T*& item) const {
    item = nullptr;
  }
  bool IsEmpty(T* const& item) const {
    return item == nullptr;
  }
};

// Low-memory version of a hash set; uses less memory than std::unordered_set since elements aren't
// boxed in separately allocated nodes. Uses linear probing to resolve collisions.
// EmptyFn needs to implement two functions: MakeEmpty(T& item) and IsEmpty(const T& item), and one
// value of T must be reserved as the "empty" marker for unused slots.
// TODO: We could get rid of this requirement by using a bitmap, though maybe this would be slower
// and more complicated.
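//
// Example usage (an illustrative sketch, not part of this header): with the default functors,
// the value-initialized T() acts as the empty marker, so for HashSet<int> the value 0 must never
// be inserted as a real element:
//
//   HashSet<int> ints;
//   ints.Insert(42);
//   auto it = ints.Find(42);
//   if (it != ints.end()) {
//     ints.Erase(it);
//   }
//
// If T() is a legitimate element value, supply a custom EmptyFn that reserves a different
// sentinel, e.g. (hypothetical):
//
//   struct MinusOneEmptyFn {
//     void MakeEmpty(int32_t& item) const { item = -1; }
//     bool IsEmpty(const int32_t& item) const { return item == -1; }
//   };
//   HashSet<int32_t, MinusOneEmptyFn> ids;  // May store 0, must never store -1.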
template <class T, class EmptyFn = DefaultEmptyFn<T>, class HashFn = std::hash<T>,
    class Pred = std::equal_to<T>, class Alloc = std::allocator<T>>
class HashSet {
  template <class Elem, class HashSetType>
  class BaseIterator {
   public:
    BaseIterator(const BaseIterator&) = default;
    BaseIterator(BaseIterator&&) = default;
    BaseIterator(HashSetType* hash_set, size_t index) : index_(index), hash_set_(hash_set) {
    }
    BaseIterator& operator=(const BaseIterator&) = default;
    BaseIterator& operator=(BaseIterator&&) = default;

    bool operator==(const BaseIterator& other) const {
      return hash_set_ == other.hash_set_ && this->index_ == other.index_;
    }

    bool operator!=(const BaseIterator& other) const {
      return !(*this == other);
    }

    BaseIterator operator++() {  // Value after modification.
      this->index_ = this->NextNonEmptySlot(this->index_, hash_set_);
      return *this;
    }

    BaseIterator operator++(int) {
      BaseIterator temp = *this;
      this->index_ = this->NextNonEmptySlot(this->index_, hash_set_);
      return temp;
    }

    Elem& operator*() const {
      DCHECK(!hash_set_->IsFreeSlot(this->index_));
      return hash_set_->ElementForIndex(this->index_);
    }

    Elem* operator->() const {
      return &**this;
    }

    // TODO: Operator -- --(int)

   private:
    size_t index_;
    HashSetType* hash_set_;

    size_t NextNonEmptySlot(size_t index, const HashSet* hash_set) const {
      const size_t num_buckets = hash_set->NumBuckets();
      DCHECK_LT(index, num_buckets);
      do {
        ++index;
      } while (index < num_buckets && hash_set->IsFreeSlot(index));
      return index;
    }

    friend class HashSet;
  };

 public:
  static constexpr double kDefaultMinLoadFactor = 0.5;
  static constexpr double kDefaultMaxLoadFactor = 0.9;
  static constexpr size_t kMinBuckets = 1000;

  typedef BaseIterator<T, HashSet> Iterator;
  typedef BaseIterator<const T, const HashSet> ConstIterator;

  // If we don't own the data, this will create a new array which owns the data.
  void Clear() {
    DeallocateStorage();
    AllocateStorage(1);
    num_elements_ = 0;
    elements_until_expand_ = 0;
  }

  HashSet() : num_elements_(0), num_buckets_(0), owns_data_(false), data_(nullptr),
      min_load_factor_(kDefaultMinLoadFactor), max_load_factor_(kDefaultMaxLoadFactor) {
    Clear();
  }

  HashSet(const HashSet& other) : num_elements_(0), num_buckets_(0), owns_data_(false),
      data_(nullptr) {
    *this = other;
  }

  HashSet(HashSet&& other) : num_elements_(0), num_buckets_(0), owns_data_(false),
      data_(nullptr) {
    *this = std::move(other);
  }

  // Construct from existing data.
  // Read from a block of memory; if make_copy_of_data is false, then data_ points into the
  // passed-in ptr.
  HashSet(const uint8_t* ptr, bool make_copy_of_data, size_t* read_count) {
    uint64_t temp;
    size_t offset = 0;
    offset = ReadFromBytes(ptr, offset, &temp);
    num_elements_ = static_cast<uint64_t>(temp);
    offset = ReadFromBytes(ptr, offset, &temp);
    num_buckets_ = static_cast<uint64_t>(temp);
    CHECK_LE(num_elements_, num_buckets_);
    offset = ReadFromBytes(ptr, offset, &temp);
    elements_until_expand_ = static_cast<uint64_t>(temp);
    offset = ReadFromBytes(ptr, offset, &min_load_factor_);
    offset = ReadFromBytes(ptr, offset, &max_load_factor_);
    if (!make_copy_of_data) {
      owns_data_ = false;
      data_ = const_cast<T*>(reinterpret_cast<const T*>(ptr + offset));
      offset += sizeof(*data_) * num_buckets_;
    } else {
      AllocateStorage(num_buckets_);
      // Read elements; note that this may not be safe for cross compilation if the elements are
      // pointer sized.
      for (size_t i = 0; i < num_buckets_; ++i) {
        offset = ReadFromBytes(ptr, offset, &data_[i]);
      }
    }
    // Caller responsible for aligning.
    *read_count = offset;
  }

  // Returns how large the table is after being written. If ptr is null, then no writing happens
  // but the size is still returned. The pointer must be 8 byte aligned.
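  // Serialized layout, matching the writes below and the reads in the deserializing constructor
  // above:
  //   uint64_t num_elements | uint64_t num_buckets | uint64_t elements_until_expand |
  //   double min_load_factor | double max_load_factor | T buckets[num_buckets]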
  size_t WriteToMemory(uint8_t* ptr) {
    size_t offset = 0;
    offset = WriteToBytes(ptr, offset, static_cast<uint64_t>(num_elements_));
    offset = WriteToBytes(ptr, offset, static_cast<uint64_t>(num_buckets_));
    offset = WriteToBytes(ptr, offset, static_cast<uint64_t>(elements_until_expand_));
    offset = WriteToBytes(ptr, offset, min_load_factor_);
    offset = WriteToBytes(ptr, offset, max_load_factor_);
    // Write elements; note that this may not be safe for cross compilation if the elements are
    // pointer sized.
    for (size_t i = 0; i < num_buckets_; ++i) {
      offset = WriteToBytes(ptr, offset, data_[i]);
    }
    // Caller responsible for aligning.
    return offset;
  }

  ~HashSet() {
    DeallocateStorage();
  }

  HashSet& operator=(HashSet&& other) {
    std::swap(data_, other.data_);
    std::swap(num_buckets_, other.num_buckets_);
    std::swap(num_elements_, other.num_elements_);
    std::swap(elements_until_expand_, other.elements_until_expand_);
    std::swap(min_load_factor_, other.min_load_factor_);
    std::swap(max_load_factor_, other.max_load_factor_);
    std::swap(owns_data_, other.owns_data_);
    return *this;
  }

  HashSet& operator=(const HashSet& other) {
    // Guard against self-assignment; DeallocateStorage() would otherwise free the source data.
    if (this == &other) {
      return *this;
    }
    DeallocateStorage();
    AllocateStorage(other.NumBuckets());
    for (size_t i = 0; i < num_buckets_; ++i) {
      ElementForIndex(i) = other.data_[i];
    }
    num_elements_ = other.num_elements_;
    elements_until_expand_ = other.elements_until_expand_;
    min_load_factor_ = other.min_load_factor_;
    max_load_factor_ = other.max_load_factor_;
    return *this;
  }


  // Lower case for C++11 range-based for loops.
  Iterator begin() {
    Iterator ret(this, 0);
    if (num_buckets_ != 0 && IsFreeSlot(ret.index_)) {
      ++ret;  // Skip all the empty slots.
    }
    return ret;
  }

  // Lower case for C++11 range-based for loops.
  Iterator end() {
    return Iterator(this, NumBuckets());
  }

  bool Empty() {
    return Size() == 0;
  }

  // Erase algorithm:
  // Make an empty slot where the iterator is pointing.
  // Scan forwards until we hit another empty slot.
  // If an element in between doesn't rehash into the range from the current empty slot to that
  // element's slot, it must have been probed forward from at or before the empty slot; in that
  // case we can move it into the empty slot and make the slot we just moved it from the new empty
  // slot.
  // Relies on maintaining the invariant that there are no empty slots from the 'ideal' index of an
  // element to its actual location/index.
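  //
  // Illustrative sketch (hypothetical hashes, 8 buckets): erasing X from
  //   [_, A(1), X(2), B(2), C(3), _, _, _]   (ideal index in parentheses)
  // first empties slot 2. B's ideal index 2 is not in (2, 3], so B moves into slot 2 and slot 3
  // becomes the hole; C's ideal index 3 is not in (3, 4], so C moves into slot 3; the scan then
  // hits the empty slot 5 and stops, leaving [_, A(1), B(2), C(3), _, _, _, _].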
  Iterator Erase(Iterator it) {
    // empty_index is the index that will become empty.
    size_t empty_index = it.index_;
    DCHECK(!IsFreeSlot(empty_index));
    size_t next_index = empty_index;
    bool filled = false;  // True if we filled the empty index.
    while (true) {
      next_index = NextIndex(next_index);
      T& next_element = ElementForIndex(next_index);
      // If the next element is empty, we are done. Make sure to clear the current empty index.
      if (emptyfn_.IsEmpty(next_element)) {
        emptyfn_.MakeEmpty(ElementForIndex(empty_index));
        break;
      }
      // Otherwise try to see if the next element can fill the current empty index.
      const size_t next_hash = hashfn_(next_element);
      // Calculate the ideal index; if it is within empty_index + 1 to next_index then there is
      // nothing we can do.
      size_t next_ideal_index = IndexForHash(next_hash);
      // Loop around if needed for our check.
      size_t unwrapped_next_index = next_index;
      if (unwrapped_next_index < empty_index) {
        unwrapped_next_index += NumBuckets();
      }
      // Loop around if needed for our check.
      size_t unwrapped_next_ideal_index = next_ideal_index;
      if (unwrapped_next_ideal_index < empty_index) {
        unwrapped_next_ideal_index += NumBuckets();
      }
      if (unwrapped_next_ideal_index <= empty_index ||
          unwrapped_next_ideal_index > unwrapped_next_index) {
        // If the target index isn't within our current range it must have been probed from before
        // the empty index.
        ElementForIndex(empty_index) = std::move(next_element);
        filled = true;  // TODO: Optimize
        empty_index = next_index;
      }
    }
    --num_elements_;
    // If we didn't fill the slot then we need to go to the next non free slot.
    if (!filled) {
      ++it;
    }
    return it;
  }

  // Find an element, returns end() if not found.
  // Allows custom key (K) types. Example of when this is useful: a set of Class* keyed by name,
  // where we want to look up a class by its name but can't afford to allocate a dummy Class
  // object on the heap just to perform the lookup.
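  //
  // Illustrative sketch (hypothetical functors, not part of this header): a HashSet<std::string>
  // can be probed with a const char* key, provided HashFn and Pred each accept the key type and
  // hash/compare it consistently with the stored elements:
  //
  //   struct StringOrCStrHash {
  //     size_t operator()(const std::string& s) const { return std::hash<std::string>()(s); }
  //     size_t operator()(const char* s) const { return std::hash<std::string>()(std::string(s)); }
  //   };
  //   struct StringOrCStrEq {
  //     bool operator()(const std::string& a, const std::string& b) const { return a == b; }
  //     bool operator()(const std::string& a, const char* b) const { return a == b; }
  //   };
  //   HashSet<std::string, DefaultEmptyFn<std::string>, StringOrCStrHash, StringOrCStrEq> names;
  //   auto it = names.Find("foo");  // Query without creating a dummy element in the set.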
  template <typename K>
  Iterator Find(const K& element) {
    return FindWithHash(element, hashfn_(element));
  }

  template <typename K>
  ConstIterator Find(const K& element) const {
    return FindWithHash(element, hashfn_(element));
  }

  template <typename K>
  Iterator FindWithHash(const K& element, size_t hash) {
    return Iterator(this, FindIndex(element, hash));
  }

  template <typename K>
  ConstIterator FindWithHash(const K& element, size_t hash) const {
    return ConstIterator(this, FindIndex(element, hash));
  }

  // Insert an element, allows duplicates.
  void Insert(const T& element) {
    InsertWithHash(element, hashfn_(element));
  }

  void InsertWithHash(const T& element, size_t hash) {
    DCHECK_EQ(hash, hashfn_(element));
    if (num_elements_ >= elements_until_expand_) {
      Expand();
      DCHECK_LT(num_elements_, elements_until_expand_);
    }
    const size_t index = FirstAvailableSlot(IndexForHash(hash));
    data_[index] = element;
    ++num_elements_;
  }

  size_t Size() const {
    return num_elements_;
  }

  void ShrinkToMaximumLoad() {
    Resize(Size() / max_load_factor_);
  }

  // Total distance that inserted elements are from their ideal slots. Used for measuring how good
  // hash functions are.
  size_t TotalProbeDistance() const {
    size_t total = 0;
    for (size_t i = 0; i < NumBuckets(); ++i) {
      const T& element = ElementForIndex(i);
      if (!emptyfn_.IsEmpty(element)) {
        size_t ideal_location = IndexForHash(hashfn_(element));
        if (ideal_location > i) {
          total += i + NumBuckets() - ideal_location;
        } else {
          total += i - ideal_location;
        }
      }
    }
    return total;
  }

  // Calculate the current load factor and return it.
  double CalculateLoadFactor() const {
    return static_cast<double>(Size()) / static_cast<double>(NumBuckets());
  }

  // Make sure that everything reinserts in the right spot. Returns the number of errors.
  size_t Verify() {
    size_t errors = 0;
    for (size_t i = 0; i < num_buckets_; ++i) {
      T& element = data_[i];
      if (!emptyfn_.IsEmpty(element)) {
        T temp;
        emptyfn_.MakeEmpty(temp);
        std::swap(temp, element);
        size_t first_slot = FirstAvailableSlot(IndexForHash(hashfn_(temp)));
        if (i != first_slot) {
          LOG(ERROR) << "Element " << i << " should be in slot " << first_slot;
          ++errors;
        }
        std::swap(temp, element);
      }
    }
    return errors;
  }

 private:
  T& ElementForIndex(size_t index) {
    DCHECK_LT(index, NumBuckets());
    DCHECK(data_ != nullptr);
    return data_[index];
  }

  const T& ElementForIndex(size_t index) const {
    DCHECK_LT(index, NumBuckets());
    DCHECK(data_ != nullptr);
    return data_[index];
  }

  size_t IndexForHash(size_t hash) const {
    return hash % num_buckets_;
  }

  size_t NextIndex(size_t index) const {
    if (UNLIKELY(++index >= num_buckets_)) {
      DCHECK_EQ(index, NumBuckets());
      return 0;
    }
    return index;
  }

  // Find the hash table slot for an element, or return NumBuckets() if not found.
  // This value for not found is important so that Iterator(this, FindIndex(...)) == end().
  template <typename K>
  size_t FindIndex(const K& element, size_t hash) const {
    DCHECK_EQ(hashfn_(element), hash);
    size_t index = IndexForHash(hash);
    while (true) {
      const T& slot = ElementForIndex(index);
      if (emptyfn_.IsEmpty(slot)) {
        return NumBuckets();
      }
      if (pred_(slot, element)) {
        return index;
      }
      index = NextIndex(index);
    }
  }

  bool IsFreeSlot(size_t index) const {
    return emptyfn_.IsEmpty(ElementForIndex(index));
  }

  size_t NumBuckets() const {
    return num_buckets_;
  }

  // Allocate a number of buckets.
  void AllocateStorage(size_t num_buckets) {
    num_buckets_ = num_buckets;
    data_ = allocfn_.allocate(num_buckets_);
    owns_data_ = true;
    for (size_t i = 0; i < num_buckets_; ++i) {
      allocfn_.construct(allocfn_.address(data_[i]));
      emptyfn_.MakeEmpty(data_[i]);
    }
  }

  void DeallocateStorage() {
    if (num_buckets_ != 0) {
      if (owns_data_) {
        for (size_t i = 0; i < NumBuckets(); ++i) {
          allocfn_.destroy(allocfn_.address(data_[i]));
        }
        allocfn_.deallocate(data_, NumBuckets());
        owns_data_ = false;
      }
      data_ = nullptr;
      num_buckets_ = 0;
    }
  }

  // Expand the set based on the load factors.
  void Expand() {
    size_t min_index = static_cast<size_t>(Size() / min_load_factor_);
    if (min_index < kMinBuckets) {
      min_index = kMinBuckets;
    }
    // Resize based on the minimum load factor.
    Resize(min_index);
    // When we hit elements_until_expand_, we are at the max load factor and must expand again.
    elements_until_expand_ = NumBuckets() * max_load_factor_;
  }

  // Expand / shrink the table to the new specified size.
  void Resize(size_t new_size) {
    DCHECK_GE(new_size, Size());
    T* const old_data = data_;
    size_t old_num_buckets = num_buckets_;
    // Reinsert all of the old elements.
    const bool owned_data = owns_data_;
    AllocateStorage(new_size);
    for (size_t i = 0; i < old_num_buckets; ++i) {
      T& element = old_data[i];
      if (!emptyfn_.IsEmpty(element)) {
        data_[FirstAvailableSlot(IndexForHash(hashfn_(element)))] = std::move(element);
      }
      if (owned_data) {
        allocfn_.destroy(allocfn_.address(element));
      }
    }
    if (owned_data) {
      allocfn_.deallocate(old_data, old_num_buckets);
    }
  }

  ALWAYS_INLINE size_t FirstAvailableSlot(size_t index) const {
    while (!emptyfn_.IsEmpty(data_[index])) {
      index = NextIndex(index);
    }
    return index;
  }

  // Return new offset.
  template <typename Elem>
  static size_t WriteToBytes(uint8_t* ptr, size_t offset, Elem n) {
    DCHECK_ALIGNED(ptr + offset, sizeof(n));
    if (ptr != nullptr) {
      *reinterpret_cast<Elem*>(ptr + offset) = n;
    }
    return offset + sizeof(n);
  }

  template <typename Elem>
  static size_t ReadFromBytes(const uint8_t* ptr, size_t offset, Elem* out) {
    DCHECK(ptr != nullptr);
    DCHECK_ALIGNED(ptr + offset, sizeof(*out));
    *out = *reinterpret_cast<const Elem*>(ptr + offset);
    return offset + sizeof(*out);
  }

  Alloc allocfn_;  // Allocator function.
  HashFn hashfn_;  // Hashing function.
  EmptyFn emptyfn_;  // MakeEmpty/IsEmpty functions.
  Pred pred_;  // Equals function.
  size_t num_elements_;  // Number of inserted elements.
  size_t num_buckets_;  // Number of hash table buckets.
  size_t elements_until_expand_;  // Maximum number of elements until we expand the table.
  bool owns_data_;  // If we own data_ and are responsible for freeing it.
  T* data_;  // Backing storage.
  double min_load_factor_;
  double max_load_factor_;
};

}  // namespace art

#endif  // ART_RUNTIME_BASE_HASH_SET_H_