1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
17 #define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
18 
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <utility>
#include <vector>
23 
24 namespace tensorflow {
25 namespace nearest_neighbor {
26 
// A simple binary min-heap. We use our own implementation because multiprobe
// for the cross-polytope hash interacts with the heap in such a way that about
// half of the insertion operations are guaranteed to land at the top of the
// heap. We exploit this fact in the AugmentedHeap below.
31 
32 // HeapBase is a base class for both the SimpleHeap and AugmentedHeap below.
template <typename KeyType, typename DataType>
class HeapBase {
 public:
  // A (key, data) pair stored in the heap. Items are ordered by key only.
  class Item {
   public:
    KeyType key;
    DataType data;

    Item() {}
    Item(const KeyType& k, const DataType& d) : key(k), data(d) {}

    bool operator<(const Item& i2) const { return key < i2.key; }
  };

  // Removes the element with the smallest key and writes its key and data to
  // *key and *data.
  // Precondition: the heap is not empty.
  void ExtractMin(KeyType* key, DataType* data) {
    assert(num_elements_ > 0);
    *key = v_[0].key;
    *data = v_[0].data;
    num_elements_ -= 1;
    // Move the last live element to the root and sift it down. If the heap
    // just became empty this is a harmless self-assignment and HeapDown(0)
    // returns immediately.
    v_[0] = v_[num_elements_];
    HeapDown(0);
  }

  // Returns true if the heap holds no live elements.
  bool IsEmpty() const { return num_elements_ == 0; }

  // This method adds an element at the end of the internal array without
  // "heapifying" the array afterwards. This is useful for setting up a heap
  // where a single call to Heapify() at the end of the initial insertion
  // operations suffices.
  void InsertUnsorted(const KeyType& key, const DataType& data) {
    if (v_.size() == static_cast<size_t>(num_elements_)) {
      v_.push_back(Item(key, data));
    } else {
      // Reuse a slot left behind by earlier extractions, Reset() or Resize()
      // instead of growing the vector.
      v_[num_elements_].key = key;
      v_[num_elements_].data = data;
    }
    num_elements_ += 1;
  }

  // Inserts (key, data) and restores the heap property by sifting the new
  // element up from the last position.
  void Insert(const KeyType& key, const DataType& data) {
    InsertUnsorted(key, data);
    HeapUp(num_elements_ - 1);
  }

  // Establishes the heap property over all live elements (bottom-up
  // construction), typically after a series of InsertUnsorted() calls.
  void Heapify() {
    // Heaps with fewer than two elements are trivially valid.
    if (num_elements_ < 2) {
      return;
    }
    // Only nodes that have at least one child need to be sifted down; the
    // last such node is the parent of the last live element.
    const int_fast32_t last_internal = parent(num_elements_ - 1);
    for (int_fast32_t cur_loc = last_internal; cur_loc >= 0; --cur_loc) {
      HeapDown(cur_loc);
    }
  }

  // Discards all elements. The backing storage is kept for reuse.
  void Reset() { num_elements_ = 0; }

  // Resizes the backing storage to exactly new_size slots. If new_size is
  // smaller than the number of live elements, the heap is truncated to its
  // first new_size array slots. Every prefix of a binary-heap array is itself
  // a valid heap, so no re-heapification is needed.
  void Resize(size_t new_size) {
    v_.resize(new_size);
    if (static_cast<size_t>(num_elements_) > new_size) {
      num_elements_ = static_cast<int_fast32_t>(new_size);
    }
  }

 protected:
  // Array-index arithmetic for a 0-rooted binary heap.
  int_fast32_t lchild(int_fast32_t x) const { return 2 * x + 1; }

  int_fast32_t rchild(int_fast32_t x) const { return 2 * x + 2; }

  int_fast32_t parent(int_fast32_t x) const { return (x - 1) / 2; }

  // Exchanges the items stored at positions a and b.
  void SwapEntries(int_fast32_t a, int_fast32_t b) { std::swap(v_[a], v_[b]); }

  // Moves the item at cur_loc towards the root while its key is smaller than
  // its parent's key.
  void HeapUp(int_fast32_t cur_loc) {
    int_fast32_t p = parent(cur_loc);
    while (cur_loc > 0 && v_[p].key > v_[cur_loc].key) {
      SwapEntries(p, cur_loc);
      cur_loc = p;
      p = parent(cur_loc);
    }
  }

  // Moves the item at cur_loc towards the leaves. At each step it swaps with
  // the smaller child (the left child on ties) while that child's key is
  // smaller. Only the first num_elements_ slots are considered.
  void HeapDown(int_fast32_t cur_loc) {
    while (true) {
      const int_fast32_t lc = lchild(cur_loc);
      const int_fast32_t rc = rchild(cur_loc);
      if (lc >= num_elements_) {
        return;  // Leaf: nothing below to compare against.
      }

      if (v_[cur_loc].key <= v_[lc].key) {
        if (rc >= num_elements_ || v_[cur_loc].key <= v_[rc].key) {
          return;  // Already no larger than both children.
        } else {
          SwapEntries(cur_loc, rc);
          cur_loc = rc;
        }
      } else {
        if (rc >= num_elements_ || v_[lc].key <= v_[rc].key) {
          SwapEntries(cur_loc, lc);
          cur_loc = lc;
        } else {
          SwapEntries(cur_loc, rc);
          cur_loc = rc;
        }
      }
    }
  }

  // Backing storage. Only the first num_elements_ entries are live; later
  // entries are stale slots kept for reuse.
  std::vector<Item> v_;
  int_fast32_t num_elements_ = 0;
};
145 
146 // A "simple" binary heap.
147 template <typename KeyType, typename DataType>
148 class SimpleHeap : public HeapBase<KeyType, DataType> {
149  public:
ReplaceTop(const KeyType & key,const DataType & data)150   void ReplaceTop(const KeyType& key, const DataType& data) {
151     this->v_[0].key = key;
152     this->v_[0].data = data;
153     this->HeapDown(0);
154   }
155 
MinKey()156   KeyType MinKey() { return this->v_[0].key; }
157 
GetData()158   std::vector<typename HeapBase<KeyType, DataType>::Item>& GetData() {
159     return this->v_;
160   }
161 };
162 
// An "augmented" heap that can hold an extra element that is guaranteed to
// be at the top of the heap. This is useful if a significant fraction of the
// insertion operations are guaranteed insertions at the top. However, the heap
// stores at most one such special top element, i.e., the heap assumes that
// ExtractMin() is called at least once between successive calls to
// InsertGuaranteedTop().
169 template <typename KeyType, typename DataType>
170 class AugmentedHeap : public HeapBase<KeyType, DataType> {
171  public:
ExtractMin(KeyType * key,DataType * data)172   void ExtractMin(KeyType* key, DataType* data) {
173     if (has_guaranteed_top_) {
174       has_guaranteed_top_ = false;
175       *key = guaranteed_top_.key;
176       *data = guaranteed_top_.data;
177     } else {
178       *key = this->v_[0].key;
179       *data = this->v_[0].data;
180       this->num_elements_ -= 1;
181       this->v_[0] = this->v_[this->num_elements_];
182       this->HeapDown(0);
183     }
184   }
185 
IsEmpty()186   bool IsEmpty() { return this->num_elements_ == 0 && !has_guaranteed_top_; }
187 
InsertGuaranteedTop(const KeyType & key,const DataType & data)188   void InsertGuaranteedTop(const KeyType& key, const DataType& data) {
189     assert(!has_guaranteed_top_);
190     has_guaranteed_top_ = true;
191     guaranteed_top_.key = key;
192     guaranteed_top_.data = data;
193   }
194 
Reset()195   void Reset() {
196     this->num_elements_ = 0;
197     has_guaranteed_top_ = false;
198   }
199 
200  protected:
201   typename HeapBase<KeyType, DataType>::Item guaranteed_top_;
202   bool has_guaranteed_top_ = false;
203 };
204 
205 }  // namespace nearest_neighbor
206 }  // namespace tensorflow
207 
208 #endif  // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
209