1 /***********************************************************************
2  * Software License Agreement (BSD License)
3  *
4  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6  *
7  * THE BSD LICENSE
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *************************************************************************/
30 
31 #ifndef OPENCV_FLANN_RESULTSET_H
32 #define OPENCV_FLANN_RESULTSET_H
33 
34 #include <algorithm>
35 #include <cstring>
36 #include <iostream>
37 #include <limits>
38 #include <set>
39 #include <vector>
40 
41 namespace cvflann
42 {
43 
/**
 * Record of a branch point met while searching a tree for neighbours.
 *
 * It pairs the node at which the search can resume later with the minimum
 * distance from the query point to everything stored below that node.
 * Records are ordered by that distance only.
 */
template <typename T, typename DistanceType>
struct BranchStruct
{
    T node;                   /* Tree node at which search resumes */
    DistanceType mindist;     /* Minimum distance to query for all nodes below. */

    /**
     * Creates a record whose members are value-initialized (zero / null for
     * built-in types), so a default-constructed record never holds
     * indeterminate values.
     */
    BranchStruct() : node(), mindist() {}

    /**
     * @param aNode node at which the search resumes
     * @param dist minimum distance from the query to the nodes below aNode
     */
    BranchStruct(const T& aNode, DistanceType dist) : node(aNode), mindist(dist) {}

    /**
     * Orders records by ascending minimum distance; the node is not compared.
     * @param rhs record to compare against
     * @return true when this record's minimum distance is strictly smaller
     */
    bool operator<(const BranchStruct<T, DistanceType>& rhs) const
    {
        return mindist<rhs.mindist;
    }
};
63 
64 
/**
 * Abstract interface shared by the result containers in this file.
 *
 * A search hands every candidate it evaluates to addPoint(); the concrete
 * container decides whether to keep it. What full() and worstDist() report
 * is defined by each implementation.
 */
template <typename DistanceType>
class ResultSet
{
public:
    /** Virtual so that an implementation can be destroyed through this interface. */
    virtual ~ResultSet() {}

    /** @return whether the implementation considers itself full */
    virtual bool full() const = 0;

    /**
     * Offers one candidate to the set.
     * @param dist distance of the candidate
     * @param index index of the candidate
     */
    virtual void addPoint(DistanceType dist, int index) = 0;

    /** @return the distance threshold currently reported by the implementation */
    virtual DistanceType worstDist() const = 0;

};
78 
79 /**
80  * KNNSimpleResultSet does not ensure that the element it holds are unique.
81  * Is used in those cases where the nearest neighbour algorithm used does not
82  * attempt to insert the same element multiple times.
83  */
84 template <typename DistanceType>
85 class KNNSimpleResultSet : public ResultSet<DistanceType>
86 {
87     int* indices;
88     DistanceType* dists;
89     int capacity;
90     int count;
91     DistanceType worst_distance_;
92 
93 public:
KNNSimpleResultSet(int capacity_)94     KNNSimpleResultSet(int capacity_) : capacity(capacity_), count(0)
95     {
96     }
97 
init(int * indices_,DistanceType * dists_)98     void init(int* indices_, DistanceType* dists_)
99     {
100         indices = indices_;
101         dists = dists_;
102         count = 0;
103         worst_distance_ = (std::numeric_limits<DistanceType>::max)();
104         dists[capacity-1] = worst_distance_;
105     }
106 
size()107     size_t size() const
108     {
109         return count;
110     }
111 
full()112     bool full() const
113     {
114         return count == capacity;
115     }
116 
117 
addPoint(DistanceType dist,int index)118     void addPoint(DistanceType dist, int index)
119     {
120         if (dist >= worst_distance_) return;
121         int i;
122         for (i=count; i>0; --i) {
123 #ifdef FLANN_FIRST_MATCH
124             if ( (dists[i-1]>dist) || ((dist==dists[i-1])&&(indices[i-1]>index)) )
125 #else
126             if (dists[i-1]>dist)
127 #endif
128             {
129                 if (i<capacity) {
130                     dists[i] = dists[i-1];
131                     indices[i] = indices[i-1];
132                 }
133             }
134             else break;
135         }
136         if (count < capacity) ++count;
137         dists[i] = dist;
138         indices[i] = index;
139         worst_distance_ = dists[capacity-1];
140     }
141 
worstDist()142     DistanceType worstDist() const
143     {
144         return worst_distance_;
145     }
146 };
147 
148 /**
149  * K-Nearest neighbour result set. Ensures that the elements inserted are unique
150  */
151 template <typename DistanceType>
152 class KNNResultSet : public ResultSet<DistanceType>
153 {
154     int* indices;
155     DistanceType* dists;
156     int capacity;
157     int count;
158     DistanceType worst_distance_;
159 
160 public:
KNNResultSet(int capacity_)161     KNNResultSet(int capacity_) : capacity(capacity_), count(0)
162     {
163     }
164 
init(int * indices_,DistanceType * dists_)165     void init(int* indices_, DistanceType* dists_)
166     {
167         indices = indices_;
168         dists = dists_;
169         count = 0;
170         worst_distance_ = (std::numeric_limits<DistanceType>::max)();
171         dists[capacity-1] = worst_distance_;
172     }
173 
size()174     size_t size() const
175     {
176         return count;
177     }
178 
full()179     bool full() const
180     {
181         return count == capacity;
182     }
183 
184 
addPoint(DistanceType dist,int index)185     void addPoint(DistanceType dist, int index)
186     {
187         if (dist >= worst_distance_) return;
188         int i;
189         for (i = count; i > 0; --i) {
190 #ifdef FLANN_FIRST_MATCH
191             if ( (dists[i-1]<=dist) && ((dist!=dists[i-1])||(indices[i-1]<=index)) )
192 #else
193             if (dists[i-1]<=dist)
194 #endif
195             {
196                 // Check for duplicate indices
197                 int j = i - 1;
198                 while ((j >= 0) && (dists[j] == dist)) {
199                     if (indices[j] == index) {
200                         return;
201                     }
202                     --j;
203                 }
204                 break;
205             }
206         }
207 
208         if (count < capacity) ++count;
209         for (int j = count-1; j > i; --j) {
210             dists[j] = dists[j-1];
211             indices[j] = indices[j-1];
212         }
213         dists[i] = dist;
214         indices[i] = index;
215         worst_distance_ = dists[capacity-1];
216     }
217 
worstDist()218     DistanceType worstDist() const
219     {
220         return worst_distance_;
221     }
222 };
223 
224 
225 /**
226  * A result-set class used when performing a radius based search.
227  */
228 template <typename DistanceType>
229 class RadiusResultSet : public ResultSet<DistanceType>
230 {
231     DistanceType radius;
232     int* indices;
233     DistanceType* dists;
234     size_t capacity;
235     size_t count;
236 
237 public:
RadiusResultSet(DistanceType radius_,int * indices_,DistanceType * dists_,int capacity_)238     RadiusResultSet(DistanceType radius_, int* indices_, DistanceType* dists_, int capacity_) :
239         radius(radius_), indices(indices_), dists(dists_), capacity(capacity_)
240     {
241         init();
242     }
243 
~RadiusResultSet()244     ~RadiusResultSet()
245     {
246     }
247 
init()248     void init()
249     {
250         count = 0;
251     }
252 
size()253     size_t size() const
254     {
255         return count;
256     }
257 
full()258     bool full() const
259     {
260         return true;
261     }
262 
addPoint(DistanceType dist,int index)263     void addPoint(DistanceType dist, int index)
264     {
265         if (dist<radius) {
266             if ((capacity>0)&&(count < capacity)) {
267                 dists[count] = dist;
268                 indices[count] = index;
269             }
270             count++;
271         }
272     }
273 
worstDist()274     DistanceType worstDist() const
275     {
276         return radius;
277     }
278 
279 };
280 
281 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
282 
283 /** Class that holds the k NN neighbors
284  * Faster than KNNResultSet as it uses a binary heap and does not maintain two arrays
285  */
286 template<typename DistanceType>
287 class UniqueResultSet : public ResultSet<DistanceType>
288 {
289 public:
290     struct DistIndex
291     {
DistIndexDistIndex292         DistIndex(DistanceType dist, unsigned int index) :
293             dist_(dist), index_(index)
294         {
295         }
296         bool operator<(const DistIndex dist_index) const
297         {
298             return (dist_ < dist_index.dist_) || ((dist_ == dist_index.dist_) && index_ < dist_index.index_);
299         }
300         DistanceType dist_;
301         unsigned int index_;
302     };
303 
304     /** Default cosntructor */
UniqueResultSet()305     UniqueResultSet() :
306         worst_distance_(std::numeric_limits<DistanceType>::max())
307     {
308     }
309 
310     /** Check the status of the set
311      * @return true if we have k NN
312      */
full()313     inline bool full() const
314     {
315         return is_full_;
316     }
317 
318     /** Remove all elements in the set
319      */
320     virtual void clear() = 0;
321 
322     /** Copy the set to two C arrays
323      * @param indices pointer to a C array of indices
324      * @param dist pointer to a C array of distances
325      * @param n_neighbors the number of neighbors to copy
326      */
327     virtual void copy(int* indices, DistanceType* dist, int n_neighbors = -1) const
328     {
329         if (n_neighbors < 0) {
330             for (typename std::set<DistIndex>::const_iterator dist_index = dist_indices_.begin(), dist_index_end =
331                      dist_indices_.end(); dist_index != dist_index_end; ++dist_index, ++indices, ++dist) {
332                 *indices = dist_index->index_;
333                 *dist = dist_index->dist_;
334             }
335         }
336         else {
337             int i = 0;
338             for (typename std::set<DistIndex>::const_iterator dist_index = dist_indices_.begin(), dist_index_end =
339                      dist_indices_.end(); (dist_index != dist_index_end) && (i < n_neighbors); ++dist_index, ++indices, ++dist, ++i) {
340                 *indices = dist_index->index_;
341                 *dist = dist_index->dist_;
342             }
343         }
344     }
345 
346     /** Copy the set to two C arrays but sort it according to the distance first
347      * @param indices pointer to a C array of indices
348      * @param dist pointer to a C array of distances
349      * @param n_neighbors the number of neighbors to copy
350      */
351     virtual void sortAndCopy(int* indices, DistanceType* dist, int n_neighbors = -1) const
352     {
353         copy(indices, dist, n_neighbors);
354     }
355 
356     /** The number of neighbors in the set
357      * @return
358      */
size()359     size_t size() const
360     {
361         return dist_indices_.size();
362     }
363 
364     /** The distance of the furthest neighbor
365      * If we don't have enough neighbors, it returns the max possible value
366      * @return
367      */
worstDist()368     inline DistanceType worstDist() const
369     {
370         return worst_distance_;
371     }
372 protected:
373     /** Flag to say if the set is full */
374     bool is_full_;
375 
376     /** The worst distance found so far */
377     DistanceType worst_distance_;
378 
379     /** The best candidates so far */
380     std::set<DistIndex> dist_indices_;
381 };
382 
383 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
384 
385 /** Class that holds the k NN neighbors
386  * Faster than KNNResultSet as it uses a binary heap and does not maintain two arrays
387  */
388 template<typename DistanceType>
389 class KNNUniqueResultSet : public UniqueResultSet<DistanceType>
390 {
391 public:
392     /** Constructor
393      * @param capacity the number of neighbors to store at max
394      */
KNNUniqueResultSet(unsigned int capacity)395     KNNUniqueResultSet(unsigned int capacity) : capacity_(capacity)
396     {
397         this->is_full_ = false;
398         this->clear();
399     }
400 
401     /** Add a possible candidate to the best neighbors
402      * @param dist distance for that neighbor
403      * @param index index of that neighbor
404      */
addPoint(DistanceType dist,int index)405     inline void addPoint(DistanceType dist, int index)
406     {
407         // Don't do anything if we are worse than the worst
408         if (dist >= worst_distance_) return;
409         dist_indices_.insert(DistIndex(dist, index));
410 
411         if (is_full_) {
412             if (dist_indices_.size() > capacity_) {
413                 dist_indices_.erase(*dist_indices_.rbegin());
414                 worst_distance_ = dist_indices_.rbegin()->dist_;
415             }
416         }
417         else if (dist_indices_.size() == capacity_) {
418             is_full_ = true;
419             worst_distance_ = dist_indices_.rbegin()->dist_;
420         }
421     }
422 
423     /** Remove all elements in the set
424      */
clear()425     void clear()
426     {
427         dist_indices_.clear();
428         worst_distance_ = std::numeric_limits<DistanceType>::max();
429         is_full_ = false;
430     }
431 
432 protected:
433     typedef typename UniqueResultSet<DistanceType>::DistIndex DistIndex;
434     using UniqueResultSet<DistanceType>::is_full_;
435     using UniqueResultSet<DistanceType>::worst_distance_;
436     using UniqueResultSet<DistanceType>::dist_indices_;
437 
438     /** The number of neighbors to keep */
439     unsigned int capacity_;
440 };
441 
442 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
443 
444 /** Class that holds the radius nearest neighbors
445  * It is more accurate than RadiusResult as it is not limited in the number of neighbors
446  */
447 template<typename DistanceType>
448 class RadiusUniqueResultSet : public UniqueResultSet<DistanceType>
449 {
450 public:
451     /** Constructor
452      * @param radius the maximum distance of a neighbor
453      */
RadiusUniqueResultSet(DistanceType radius)454     RadiusUniqueResultSet(DistanceType radius) :
455         radius_(radius)
456     {
457         is_full_ = true;
458     }
459 
460     /** Add a possible candidate to the best neighbors
461      * @param dist distance for that neighbor
462      * @param index index of that neighbor
463      */
addPoint(DistanceType dist,int index)464     void addPoint(DistanceType dist, int index)
465     {
466         if (dist <= radius_) dist_indices_.insert(DistIndex(dist, index));
467     }
468 
469     /** Remove all elements in the set
470      */
clear()471     inline void clear()
472     {
473         dist_indices_.clear();
474     }
475 
476 
477     /** Check the status of the set
478      * @return alwys false
479      */
full()480     inline bool full() const
481     {
482         return true;
483     }
484 
485     /** The distance of the furthest neighbor
486      * If we don't have enough neighbors, it returns the max possible value
487      * @return
488      */
worstDist()489     inline DistanceType worstDist() const
490     {
491         return radius_;
492     }
493 private:
494     typedef typename UniqueResultSet<DistanceType>::DistIndex DistIndex;
495     using UniqueResultSet<DistanceType>::dist_indices_;
496     using UniqueResultSet<DistanceType>::is_full_;
497 
498     /** The furthest distance a neighbor can be */
499     DistanceType radius_;
500 };
501 
502 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
503 
504 /** Class that holds the k NN neighbors within a radius distance
505  */
506 template<typename DistanceType>
507 class KNNRadiusUniqueResultSet : public KNNUniqueResultSet<DistanceType>
508 {
509 public:
510     /** Constructor
511      * @param capacity the number of neighbors to store at max
512      * @param radius the maximum distance of a neighbor
513      */
KNNRadiusUniqueResultSet(unsigned int capacity,DistanceType radius)514     KNNRadiusUniqueResultSet(unsigned int capacity, DistanceType radius)
515     {
516         this->capacity_ = capacity;
517         this->radius_ = radius;
518         this->dist_indices_.reserve(capacity_);
519         this->clear();
520     }
521 
522     /** Remove all elements in the set
523      */
clear()524     void clear()
525     {
526         dist_indices_.clear();
527         worst_distance_ = radius_;
528         is_full_ = false;
529     }
530 private:
531     using KNNUniqueResultSet<DistanceType>::dist_indices_;
532     using KNNUniqueResultSet<DistanceType>::is_full_;
533     using KNNUniqueResultSet<DistanceType>::worst_distance_;
534 
535     /** The maximum number of neighbors to consider */
536     unsigned int capacity_;
537 
538     /** The maximum distance of a neighbor */
539     DistanceType radius_;
540 };
541 }
542 
543 #endif //OPENCV_FLANN_RESULTSET_H
544