1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
20 //
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
24 //
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 /****************************************************************************************\
44 * K-Nearest Neighbors Classifier *
45 \****************************************************************************************/
46
47 // k Nearest Neighbors
CvKNearest()48 CvKNearest::CvKNearest()
49 {
50 samples = 0;
51 clear();
52 }
53
54
~CvKNearest()55 CvKNearest::~CvKNearest()
56 {
57 clear();
58 }
59
60
CvKNearest(const CvMat * _train_data,const CvMat * _responses,const CvMat * _sample_idx,bool _is_regression,int _max_k)61 CvKNearest::CvKNearest( const CvMat* _train_data, const CvMat* _responses,
62 const CvMat* _sample_idx, bool _is_regression, int _max_k )
63 {
64 samples = 0;
65 train( _train_data, _responses, _sample_idx, _is_regression, _max_k, false );
66 }
67
68
clear()69 void CvKNearest::clear()
70 {
71 while( samples )
72 {
73 CvVectors* next_samples = samples->next;
74 cvFree( &samples->data.fl );
75 cvFree( &samples );
76 samples = next_samples;
77 }
78 var_count = 0;
79 total = 0;
80 max_k = 0;
81 }
82
83
get_max_k() const84 int CvKNearest::get_max_k() const { return max_k; }
85
get_var_count() const86 int CvKNearest::get_var_count() const { return var_count; }
87
is_regression() const88 bool CvKNearest::is_regression() const { return regression; }
89
get_sample_count() const90 int CvKNearest::get_sample_count() const { return total; }
91
train(const CvMat * _train_data,const CvMat * _responses,const CvMat * _sample_idx,bool _is_regression,int _max_k,bool _update_base)92 bool CvKNearest::train( const CvMat* _train_data, const CvMat* _responses,
93 const CvMat* _sample_idx, bool _is_regression,
94 int _max_k, bool _update_base )
95 {
96 bool ok = false;
97 CvMat* responses = 0;
98
99 CV_FUNCNAME( "CvKNearest::train" );
100
101 __BEGIN__;
102
103 CvVectors* _samples;
104 float** _data;
105 int _count, _dims, _dims_all, _rsize;
106
107 if( !_update_base )
108 clear();
109
110 // Prepare training data and related parameters.
111 // Treat categorical responses as ordered - to prevent class label compression and
112 // to enable entering new classes in the updates
113 CV_CALL( cvPrepareTrainData( "CvKNearest::train", _train_data, CV_ROW_SAMPLE,
114 _responses, CV_VAR_ORDERED, 0, _sample_idx, true, (const float***)&_data,
115 &_count, &_dims, &_dims_all, &responses, 0, 0 ));
116
117 if( _update_base && _dims != var_count )
118 CV_ERROR( CV_StsBadArg, "The newly added data have different dimensionality" );
119
120 if( !_update_base )
121 {
122 if( _max_k < 1 )
123 CV_ERROR( CV_StsOutOfRange, "max_k must be a positive number" );
124
125 regression = _is_regression;
126 var_count = _dims;
127 max_k = _max_k;
128 }
129
130 _rsize = _count*sizeof(float);
131 CV_CALL( _samples = (CvVectors*)cvAlloc( sizeof(*_samples) + _rsize ));
132 _samples->next = samples;
133 _samples->type = CV_32F;
134 _samples->data.fl = _data;
135 _samples->count = _count;
136 total += _count;
137
138 samples = _samples;
139 memcpy( _samples + 1, responses->data.fl, _rsize );
140
141 ok = true;
142
143 __END__;
144
145 return ok;
146 }
147
148
149
find_neighbors_direct(const CvMat * _samples,int k,int start,int end,float * neighbor_responses,const float ** neighbors,float * dist) const150 void CvKNearest::find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
151 float* neighbor_responses, const float** neighbors, float* dist ) const
152 {
153 int i, j, count = end - start, k1 = 0, k2 = 0, d = var_count;
154 CvVectors* s = samples;
155
156 for( ; s != 0; s = s->next )
157 {
158 int n = s->count;
159 for( j = 0; j < n; j++ )
160 {
161 for( i = 0; i < count; i++ )
162 {
163 double sum = 0;
164 Cv32suf si;
165 const float* v = s->data.fl[j];
166 const float* u = (float*)(_samples->data.ptr + _samples->step*(start + i));
167 Cv32suf* dd = (Cv32suf*)(dist + i*k);
168 float* nr;
169 const float** nn;
170 int t, ii, ii1;
171
172 for( t = 0; t <= d - 4; t += 4 )
173 {
174 double t0 = u[t] - v[t], t1 = u[t+1] - v[t+1];
175 double t2 = u[t+2] - v[t+2], t3 = u[t+3] - v[t+3];
176 sum += t0*t0 + t1*t1 + t2*t2 + t3*t3;
177 }
178
179 for( ; t < d; t++ )
180 {
181 double t0 = u[t] - v[t];
182 sum += t0*t0;
183 }
184
185 si.f = (float)sum;
186 for( ii = k1-1; ii >= 0; ii-- )
187 if( si.i > dd[ii].i )
188 break;
189 if( ii >= k-1 )
190 continue;
191
192 nr = neighbor_responses + i*k;
193 nn = neighbors ? neighbors + (start + i)*k : 0;
194 for( ii1 = k2 - 1; ii1 > ii; ii1-- )
195 {
196 dd[ii1+1].i = dd[ii1].i;
197 nr[ii1+1] = nr[ii1];
198 if( nn ) nn[ii1+1] = nn[ii1];
199 }
200 dd[ii+1].i = si.i;
201 nr[ii+1] = ((float*)(s + 1))[j];
202 if( nn )
203 nn[ii+1] = v;
204 }
205 k1 = MIN( k1+1, k );
206 k2 = MIN( k1, k-1 );
207 }
208 }
209 }
210
211
write_results(int k,int k1,int start,int end,const float * neighbor_responses,const float * dist,CvMat * _results,CvMat * _neighbor_responses,CvMat * _dist,Cv32suf * sort_buf) const212 float CvKNearest::write_results( int k, int k1, int start, int end,
213 const float* neighbor_responses, const float* dist,
214 CvMat* _results, CvMat* _neighbor_responses,
215 CvMat* _dist, Cv32suf* sort_buf ) const
216 {
217 float result = 0.f;
218 int i, j, j1, count = end - start;
219 double inv_scale = 1./k1;
220 int rstep = _results && !CV_IS_MAT_CONT(_results->type) ? _results->step/sizeof(result) : 1;
221
222 for( i = 0; i < count; i++ )
223 {
224 const Cv32suf* nr = (const Cv32suf*)(neighbor_responses + i*k);
225 float* dst;
226 float r;
227 if( _results || start+i == 0 )
228 {
229 if( regression )
230 {
231 double s = 0;
232 for( j = 0; j < k1; j++ )
233 s += nr[j].f;
234 r = (float)(s*inv_scale);
235 }
236 else
237 {
238 int prev_start = 0, best_count = 0, cur_count;
239 Cv32suf best_val;
240
241 for( j = 0; j < k1; j++ )
242 sort_buf[j].i = nr[j].i;
243
244 for( j = k1-1; j > 0; j-- )
245 {
246 bool swap_fl = false;
247 for( j1 = 0; j1 < j; j1++ )
248 if( sort_buf[j1].i > sort_buf[j1+1].i )
249 {
250 int t;
251 CV_SWAP( sort_buf[j1].i, sort_buf[j1+1].i, t );
252 swap_fl = true;
253 }
254 if( !swap_fl )
255 break;
256 }
257
258 best_val.i = 0;
259 for( j = 1; j <= k1; j++ )
260 if( j == k1 || sort_buf[j].i != sort_buf[j-1].i )
261 {
262 cur_count = j - prev_start;
263 if( best_count < cur_count )
264 {
265 best_count = cur_count;
266 best_val.i = sort_buf[j-1].i;
267 }
268 prev_start = j;
269 }
270 r = best_val.f;
271 }
272
273 if( start+i == 0 )
274 result = r;
275
276 if( _results )
277 _results->data.fl[(start + i)*rstep] = r;
278 }
279
280 if( _neighbor_responses )
281 {
282 dst = (float*)(_neighbor_responses->data.ptr +
283 (start + i)*_neighbor_responses->step);
284 for( j = 0; j < k1; j++ )
285 dst[j] = nr[j].f;
286 for( ; j < k; j++ )
287 dst[j] = 0.f;
288 }
289
290 if( _dist )
291 {
292 dst = (float*)(_dist->data.ptr + (start + i)*_dist->step);
293 for( j = 0; j < k1; j++ )
294 dst[j] = dist[j + i*k];
295 for( ; j < k; j++ )
296 dst[j] = 0.f;
297 }
298 }
299
300 return result;
301 }
302
303
304
find_nearest(const CvMat * _samples,int k,CvMat * _results,const float ** _neighbors,CvMat * _neighbor_responses,CvMat * _dist) const305 float CvKNearest::find_nearest( const CvMat* _samples, int k, CvMat* _results,
306 const float** _neighbors, CvMat* _neighbor_responses, CvMat* _dist ) const
307 {
308 float result = 0.f;
309 bool local_alloc = false;
310 float* buf = 0;
311 const int max_blk_count = 128, max_buf_sz = 1 << 12;
312
313 CV_FUNCNAME( "CvKNearest::find_nearest" );
314
315 __BEGIN__;
316
317 int i, count, count_scale, blk_count0, blk_count = 0, buf_sz, k1;
318
319 if( !samples )
320 CV_ERROR( CV_StsError, "The search tree must be constructed first using train method" );
321
322 if( !CV_IS_MAT(_samples) ||
323 CV_MAT_TYPE(_samples->type) != CV_32FC1 ||
324 _samples->cols != var_count )
325 CV_ERROR( CV_StsBadArg, "Input samples must be floating-point matrix (<num_samples>x<var_count>)" );
326
327 if( _results && (!CV_IS_MAT(_results) ||
328 _results->cols != 1 && _results->rows != 1 ||
329 _results->cols + _results->rows - 1 != _samples->rows) )
330 CV_ERROR( CV_StsBadArg,
331 "The results must be 1d vector containing as much elements as the number of samples" );
332
333 if( _results && CV_MAT_TYPE(_results->type) != CV_32FC1 &&
334 (CV_MAT_TYPE(_results->type) != CV_32SC1 || regression))
335 CV_ERROR( CV_StsUnsupportedFormat,
336 "The results must be floating-point or integer (in case of classification) vector" );
337
338 if( k < 1 || k > max_k )
339 CV_ERROR( CV_StsOutOfRange, "k must be within 1..max_k range" );
340
341 if( _neighbor_responses )
342 {
343 if( !CV_IS_MAT(_neighbor_responses) || CV_MAT_TYPE(_neighbor_responses->type) != CV_32FC1 ||
344 _neighbor_responses->rows != _samples->rows || _neighbor_responses->cols != k )
345 CV_ERROR( CV_StsBadArg,
346 "The neighbor responses (if present) must be floating-point matrix of <num_samples> x <k> size" );
347 }
348
349 if( _dist )
350 {
351 if( !CV_IS_MAT(_dist) || CV_MAT_TYPE(_dist->type) != CV_32FC1 ||
352 _dist->rows != _samples->rows || _dist->cols != k )
353 CV_ERROR( CV_StsBadArg,
354 "The distances from the neighbors (if present) must be floating-point matrix of <num_samples> x <k> size" );
355 }
356
357 count = _samples->rows;
358 count_scale = k*2*sizeof(float);
359 blk_count0 = MIN( count, max_blk_count );
360 buf_sz = MIN( blk_count0 * count_scale, max_buf_sz );
361 blk_count0 = MAX( buf_sz/count_scale, 1 );
362 blk_count0 += blk_count0 % 2;
363 blk_count0 = MIN( blk_count0, count );
364 buf_sz = blk_count0 * count_scale + k*sizeof(float);
365 k1 = get_sample_count();
366 k1 = MIN( k1, k );
367
368 if( buf_sz <= CV_MAX_LOCAL_SIZE )
369 {
370 buf = (float*)cvStackAlloc( buf_sz );
371 local_alloc = true;
372 }
373 else
374 CV_CALL( buf = (float*)cvAlloc( buf_sz ));
375
376 for( i = 0; i < count; i += blk_count )
377 {
378 blk_count = MIN( count - i, blk_count0 );
379 float* neighbor_responses = buf;
380 float* dist = buf + blk_count*k;
381 Cv32suf* sort_buf = (Cv32suf*)(dist + blk_count*k);
382
383 find_neighbors_direct( _samples, k, i, i + blk_count,
384 neighbor_responses, _neighbors, dist );
385
386 float r = write_results( k, k1, i, i + blk_count, neighbor_responses, dist,
387 _results, _neighbor_responses, _dist, sort_buf );
388 if( i == 0 )
389 result = r;
390 }
391
392 __END__;
393
394 if( !local_alloc )
395 cvFree( &buf );
396
397 return result;
398 }
399
400 /* End of file */
401
402