/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                          License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>

namespace cv {
namespace ml {

using std::vector;

TreeParams::TreeParams()
{
    maxDepth = INT_MAX;
    minSampleCount = 10;
    regressionAccuracy = 0.01f;
    useSurrogates = false;
    maxCategories = 10;
    CVFolds = 10;
    use1SERule = true;
    truncatePrunedTree = true;
    priors = Mat();
}

TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
                       double _regressionAccuracy, bool _useSurrogates,
                       int _maxCategories, int _CVFolds,
                       bool _use1SERule, bool _truncatePrunedTree,
                       const Mat& _priors)
{
    maxDepth = _maxDepth;
    minSampleCount = _minSampleCount;
    regressionAccuracy = (float)_regressionAccuracy;
    useSurrogates = _useSurrogates;
    maxCategories = _maxCategories;
    CVFolds = _CVFolds;
    use1SERule = _use1SERule;
    truncatePrunedTree = _truncatePrunedTree;
    priors = _priors;
}

DTrees::Node::Node()
{
    classIdx = 0;
    value = 0;
    parent = left = right = split = defaultDir = -1;
}

DTrees::Split::Split()
{
    varIdx = 0;
    inversed = false;
    quality = 0.f;
    next = -1;
    c = 0.f;
    subsetOfs = 0;
}
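
// Notes on the Split fields (inferred from their use in this file):
//  * for an ordered variable, `c` is the threshold: samples with value <= c go left;
//  * for a categorical variable, `subsetOfs` is an offset into the `subsets` bit
//    array (32 categories per int) that selects the branch per category;
//  * `next` chains additional (surrogate) splits of the same node, see readNode().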


DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
    data = _data;
    vector<int> subsampleIdx;
    Mat sidx0 = _data->getTrainSampleIdx();
    if( !sidx0.empty() )
    {
        sidx0.copyTo(sidx);
        std::sort(sidx.begin(), sidx.end());
    }
    else
    {
        int n = _data->getNSamples();
        setRangeVector(sidx, n);
    }

    maxSubsetSize = 0;
}

DTreesImpl::DTreesImpl() {}
DTreesImpl::~DTreesImpl() {}
void DTreesImpl::clear()
{
    varIdx.clear();
    compVarIdx.clear();
    varType.clear();
    catOfs.clear();
    catMap.clear();
    roots.clear();
    nodes.clear();
    splits.clear();
    subsets.clear();
    classLabels.clear();

    w.release();
    _isClassifier = false;
}

void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
    clear();
    w = makePtr<WorkData>(data);

    Mat vtype = data->getVarType();
    vtype.copyTo(varType);

    data->getCatOfs().copyTo(catOfs);
    data->getCatMap().copyTo(catMap);
    data->getDefaultSubstValues().copyTo(missingSubst);

    int nallvars = data->getNAllVars();

    Mat vidx0 = data->getVarIdx();
    if( !vidx0.empty() )
        vidx0.copyTo(varIdx);
    else
        setRangeVector(varIdx, nallvars);

    initCompVarIdx();

    w->maxSubsetSize = 0;

    int i, nvars = (int)varIdx.size();
    for( i = 0; i < nvars; i++ )
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));

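    // maxSubsetSize so far holds the largest category count among the active
    // variables; convert it to the number of 32-bit words needed to store a
    // per-category bitset (at least one word).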
    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);

    data->getSampleWeights().copyTo(w->sample_weights);

    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;

    if( _isClassifier )
    {
        data->getNormCatResponses().copyTo(w->cat_responses);
        data->getClassLabels().copyTo(classLabels);
        int nclasses = (int)classLabels.size();

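        // class priors, when given, act as per-class multipliers of the sample
        // weights: every sample of class ci gets its weight scaled by priors[ci].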
        Mat class_weights = params.priors;
        if( !class_weights.empty() )
        {
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
            {
                Mat temp;
                class_weights.convertTo(temp, CV_64F);
                class_weights = temp;
            }
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );

            int nsamples = (int)w->cat_responses.size();
            const double* cw = class_weights.ptr<double>();
            CV_Assert( (int)w->sample_weights.size() == nsamples );

            for( i = 0; i < nsamples; i++ )
            {
                int ci = w->cat_responses[i];
                CV_Assert( 0 <= ci && ci < nclasses );
                w->sample_weights[i] *= cw[ci];
            }
        }
    }
    else
        data->getResponses().copyTo(w->ord_responses);
}


void DTreesImpl::initCompVarIdx()
{
    int nallvars = (int)varType.size();
    compVarIdx.assign(nallvars, -1);
    int i, nvars = (int)varIdx.size(), prevIdx = -1;
    for( i = 0; i < nvars; i++ )
    {
        int vi = varIdx[i];
        CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
        prevIdx = vi;
        compVarIdx[vi] = i;
    }
}

void DTreesImpl::endTraining()
{
    w.release();
}

bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
{
    startTraining(trainData, flags);
    bool ok = addTree( w->sidx ) >= 0;
    w.release();
    endTraining();
    return ok;
}

const vector<int>& DTreesImpl::getActiveVars()
{
    return varIdx;
}

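// Grows a single tree on the samples listed in sidx. The tree is first built in
// the per-training "work" representation (w->wnodes/wsplits/wsubsets) and then
// copied into the compact persistent arrays (nodes, splits, subsets).
// Returns the index of the new root in `nodes`.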
int DTreesImpl::addTree(const vector<int>& sidx )
{
    size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size();

    w->wnodes.reserve(n);
    w->wsplits.reserve(n);
    w->wsubsets.reserve(n*w->maxSubsetSize);
    w->wnodes.clear();
    w->wsplits.clear();
    w->wsubsets.clear();

    int cv_n = params.getCVFolds();

    if( cv_n > 0 )
    {
        w->cv_Tn.resize(n*cv_n);
        w->cv_node_error.resize(n*cv_n);
        w->cv_node_risk.resize(n*cv_n);
    }

    // build the tree recursively
    int w_root = addNodeAndTrySplit(-1, sidx);
    int maxdepth = INT_MAX;//pruneCV(root);

    int w_nidx = w_root, pidx = -1, depth = 0;
    int root = (int)nodes.size();

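    // Depth-first traversal of the work tree: copy each WNode into a compact
    // Node, converting its split (and, for categorical splits, the category
    // bitset) into the persistent splits/subsets arrays as we go.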
    for(;;)
    {
        const WNode& wnode = w->wnodes[w_nidx];
        Node node;
        node.parent = pidx;
        node.classIdx = wnode.class_idx;
        node.value = wnode.value;
        node.defaultDir = wnode.defaultDir;

        int wsplit_idx = wnode.split;
        if( wsplit_idx >= 0 )
        {
            const WSplit& wsplit = w->wsplits[wsplit_idx];
            Split split;
            split.c = wsplit.c;
            split.quality = wsplit.quality;
            split.inversed = wsplit.inversed;
            split.varIdx = wsplit.varIdx;
            split.subsetOfs = -1;
            if( wsplit.subsetOfs >= 0 )
            {
                int ssize = getSubsetSize(split.varIdx);
                split.subsetOfs = (int)subsets.size();
                subsets.resize(split.subsetOfs + ssize);
                // When ssize == 0 the resize() above does nothing and
                // &subsets[split.subsetOfs] would be out of range, so guard the
                // memcpy; the guard also skips a pointless zero-length copy.
                if(ssize > 0)
                {
                    memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
                }
            }
            node.split = (int)splits.size();
            splits.push_back(split);
        }
        int nidx = (int)nodes.size();
        nodes.push_back(node);
        if( pidx >= 0 )
        {
            int w_pidx = w->wnodes[w_nidx].parent;
            if( w->wnodes[w_pidx].left == w_nidx )
            {
                nodes[pidx].left = nidx;
            }
            else
            {
                CV_Assert(w->wnodes[w_pidx].right == w_nidx);
                nodes[pidx].right = nidx;
            }
        }

        if( wnode.left >= 0 && depth+1 < maxdepth )
        {
            w_nidx = wnode.left;
            pidx = nidx;
            depth++;
        }
        else
        {
            int w_pidx = wnode.parent;
            while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
            {
                w_nidx = w_pidx;
                w_pidx = w->wnodes[w_pidx].parent;
                nidx = pidx;
                pidx = nodes[pidx].parent;
                depth--;
            }

            if( w_pidx < 0 )
                break;

            w_nidx = w->wnodes[w_pidx].right;
            CV_Assert( w_nidx >= 0 );
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::setDParams(const TreeParams& _params)
{
    params = _params;
}

int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
{
    w->wnodes.push_back(WNode());
    int nidx = (int)(w->wnodes.size() - 1);
    WNode& node = w->wnodes.back();

    node.parent = parent;
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
    int nfolds = params.getCVFolds();

    if( nfolds > 0 )
    {
        w->cv_Tn.resize((nidx+1)*nfolds);
        w->cv_node_error.resize((nidx+1)*nfolds);
        w->cv_node_risk.resize((nidx+1)*nfolds);
    }

    int i, n = node.sample_count = (int)sidx.size();
    bool can_split = true;
    vector<int> sleft, sright;

    calcValue( nidx, sidx );

    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
        can_split = false;
    else if( _isClassifier )
    {
        const int* responses = &w->cat_responses[0];
        const int* s = &sidx[0];
        int first = responses[s[0]];
        for( i = 1; i < n; i++ )
            if( responses[s[i]] != first )
                break;
        if( i == n )
            can_split = false;
    }
    else
    {
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
            can_split = false;
    }

    if( can_split )
        node.split = findBestSplit( sidx );

    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);

    if( node.split >= 0 )
    {
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
        if( params.useSurrogates )
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");

        int left = addNodeAndTrySplit( nidx, sleft );
        int right = addNodeAndTrySplit( nidx, sright );
        w->wnodes[nidx].left = left;
        w->wnodes[nidx].right = right;
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
    }

    return nidx;
}

int DTreesImpl::findBestSplit( const vector<int>& _sidx )
{
    const vector<int>& activeVars = getActiveVars();
    int splitidx = -1;
    int vi_, nv = (int)activeVars.size();
    AutoBuffer<int> buf(w->maxSubsetSize*2);
    int *subset = buf, *best_subset = subset + w->maxSubsetSize;
    WSplit split, best_split;
    best_split.quality = 0.;

    for( vi_ = 0; vi_ < nv; vi_++ )
    {
        int vi = activeVars[vi_];
        if( varType[vi] == VAR_CATEGORICAL )
        {
            if( _isClassifier )
                split = findSplitCatClass(vi, _sidx, 0, subset);
            else
                split = findSplitCatReg(vi, _sidx, 0, subset);
        }
        else
        {
            if( _isClassifier )
                split = findSplitOrdClass(vi, _sidx, 0);
            else
                split = findSplitOrdReg(vi, _sidx, 0);
        }
        if( split.quality > best_split.quality )
        {
            best_split = split;
            std::swap(subset, best_subset);
        }
    }

    if( best_split.quality > 0 )
    {
        int best_vi = best_split.varIdx;
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
        w->wsubsets.resize(prevsz + ssize);
        for( i = 0; i < ssize; i++ )
            w->wsubsets[prevsz + i] = best_subset[i];
        best_split.subsetOfs = prevsz;
        w->wsplits.push_back(best_split);
        splitidx = (int)(w->wsplits.size()-1);
    }

    return splitidx;
}

void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();

    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));

    if( cv_n > 0 )
    {
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }

    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        double* cls_count = buf;
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            double *cv_sum = buf, *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            for( j = 0; j < cv_n; j++ )
            {
                // s/s2 are fold-j sums; si/s2i are sums over the remaining folds
                double s = cv_sum[j], si = sum - s;
                double s2 = cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }

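        // weighted-SSE identity: sum_i w_i*(y_i - mu)^2 = sum2 - sum^2/sumw,
        // with mu = sum/sumw, so the node risk below needs only the
        // accumulated first and second moments.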
        node->node_risk = sum2 - (sum/sumw)*sum;
        node->value = sum/sumw;
    }
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    const double epsilon = FLT_EPSILON*2;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)(uchar*)buf;
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }

    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        L += wval; R -= wval;
        int idx = responses[si];
        double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        if( values[curr] + epsilon < values[next] )
        {
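            // (lsum2*R + rsum2*L)/(L*R) == lsum2/L + rsum2/R, i.e. the sum of
            // squared class weights normalized by each side's total weight;
            // maximizing it is equivalent to minimizing the weighted Gini
            // impurity of the two children.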
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf, *c_weights = buf + n;
    bool modified = true;
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        labels[i] = i < k ? i : r.uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

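            // alpha and beta are reciprocal L1 norms, so the distance below
            // compares L1-normalized class-distribution vectors:
            // dist2 = || v/||v||_1 - s/||s||_1 ||^2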
            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}

DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    int _mi = getCatCount(vi), mi = _mi;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = (double*)buf;
    double* rc = lc + m;
    double* _cjk = rc + m*2, *cjk = _cjk;
    double* c_weights = cjk + m*mi;

    int* labels = (int*)(buf + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    int i, j, k, si, idx;
    double L = 0, R = 0;
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        si = _sidx[i];
        j = labels[i];
        k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        if( mi > params.getMaxCategories() )
        {
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;
            cluster_labels = (int*)(cjk + m*mi);
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        dbl_ptr = (double**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        double sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
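            // Successive Gray codes differ in exactly one bit, so `diff` is a
            // power of two; converting it to float and extracting the IEEE-754
            // exponent ((u.i >> 23) - 127) yields log2(diff), i.e. the index of
            // the category whose side changed.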
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        double* crow = cjk + idx*m;
        double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] + t;
                double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] - t;
                double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
        if( m == 2 )
        {
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const float epsilon = FLT_EPSILON*2;
    const double* weights = &w->sample_weights[0];
    int n = (int)_sidx.size();

    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));

    float* values = (float*)(uchar*)buf;
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];

    int i, si, best_i = -1;
    double L = 0, R = 0;
    double best_val = initQuality, lsum = 0, rsum = 0;

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // find the optimal split
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        si = _sidx[curr];
        double wval = weights[si];
        double t = responses[si]*wval;
        L += wval; R -= wval;
        lsum += t; rsum -= t;

        if( values[curr] + epsilon < values[next] )
        {
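            // (lsum*lsum*R + rsum*rsum*L)/(L*R) == lsum^2/L + rsum^2/R;
            // since the total SSE equals sum2 - lsum^2/L - rsum^2/R,
            // maximizing this value minimizes the children's weighted SSE.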
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);

    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = (double*)buf + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }

    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert back to unnormalized sums
    // (there should be very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];

        if( ni > FLT_EPSILON )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}

int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
                         vector<int>& _sleft, vector<int>& _sright )
{
    WSplit split = w->wsplits[splitidx];
    int i, si, n = (int)_sidx.size(), vi = split.varIdx;
    _sleft.reserve(n);
    _sright.reserve(n);
    _sleft.clear();
    _sright.clear();

    AutoBuffer<float> buf(n);
    int mi = getCatCount(vi);
    double wleft = 0, wright = 0;
    const double* weights = &w->sample_weights[0];

    if( mi <= 0 ) // split on an ordered variable
    {
        float c = split.c;
        float* values = buf;
        w->data->getValues(vi, _sidx, values);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            if( values[i] <= c )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    else
    {
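        // For a categorical split the per-category direction comes from the
        // subset bitmask stored at split.subsetOfs; a negative value from
        // CV_DTREE_CAT_DIR sends the sample to the left child.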
        const int* subset = &w->wsubsets[split.subsetOfs];
        int* cat_labels = (int*)(float*)buf;
        w->data->getNormCatValues(vi, _sidx, cat_labels);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            unsigned u = cat_labels[i];
            if( CV_DTREE_CAT_DIR(u, subset) < 0 )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
    return wleft > wright ? -1 : 1;
}

int DTreesImpl::pruneCV( int root )
{
    vector<double> ab;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if needed, apply the 1SE rule).
    // 3. store the best index and cut the branches.

    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = params.use1SERule != 0 && _isClassifier;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // build the main tree sequence, calculate alphas
    for(;;tree_count++)
    {
        double min_alpha = updateTreeRNC(root, tree_count, -1);
        if( cutTree(root, tree_count, -1, min_alpha) )
            break;

        ab.push_back(min_alpha);
    }

    if( tree_count > 0 )
    {
        ab[0] = 0.;

        for( ti = 1; ti < tree_count-1; ti++ )
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
        ab[tree_count-1] = DBL_MAX*0.5;

        Mat err_jk(cv_n, tree_count, CV_64F);

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tj < tree_count; tj++ )
            {
                double min_alpha = updateTreeRNC(root, tj, j);
                if( cutTree(root, tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab[tk] > min_alpha )
                        break;
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err_jk.at<double>(j, ti);
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    return min_idx;
}

double DTreesImpl::updateTreeRNC( int root, double T, int fold )
{
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
    double min_alpha = DBL_MAX;

    for(;;)
    {
        WNode *node = 0, *parent = 0;

        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
                }
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
        {
            node = &w->wnodes[nidx];
            parent = &w->wnodes[pidx];
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

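            // CART cost-complexity measure: alpha is the increase in risk per
            // pruned leaf, (R(node) - R(subtree)) / (#leaves - 1); subtrees with
            // the smallest alpha are the first candidates for pruning.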
            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = std::min( min_alpha, parent->alpha );
        }

        if( pidx < 0 )
            break;

        node = &w->wnodes[nidx];
        parent = &w->wnodes[pidx];
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        nidx = parent->right;
    }

    return min_alpha;
}

bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
{
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
    WNode* node = &w->wnodes[root];
    if( node->left < 0 )
        return true;

    for(;;)
    {
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    w->cv_Tn[nidx*cv_n + fold] = T;
                else
                    node->Tn = T;
                if( nidx == root )
                    return true;
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
            ;

        if( pidx < 0 )
            break;

        nidx = w->wnodes[pidx].right;
    }

    return false;
}

float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
{
    CV_Assert( sample.type() == CV_32F );

    int predictType = flags & PREDICT_MASK;
    int nvars = (int)varIdx.size();
    if( nvars == 0 )
        nvars = (int)varType.size();
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
    int catbufsize = ncats > 0 ? nvars : 0;
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
    int* votes = buf;
    int* catbuf = votes + nclasses;
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
    const uchar* vtype = &varType[0];
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
    const float* psample = sample.ptr<float>();
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
    double sum = 0.;
    int lastClassIdx = -1;
    const float MISSED_VAL = TrainData::missingValue();

    for( i = 0; i < catbufsize; i++ )
        catbuf[i] = -1;

    if( predictType == PREDICT_AUTO )
    {
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
            PREDICT_SUM : PREDICT_MAX_VOTE;
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        for( i = 0; i < nclasses; i++ )
            votes[i] = 0;
    }

    for( int ridx = range.start; ridx < range.end; ridx++ )
    {
        int nidx = roots[ridx], prev = nidx, c = 0;

        for(;;)
        {
            prev = nidx;
            const Node& node = nodes[nidx];
            if( node.split < 0 )
                break;
            const Split& split = splits[node.split];
            int vi = split.varIdx;
            int ci = cvidx ? cvidx[vi] : vi;
            float val = psample[ci*sstep];
            if( val == MISSED_VAL )
            {
                if( !missingSubstPtr )
                {
                    nidx = node.defaultDir < 0 ? node.left : node.right;
                    continue;
                }
                val = missingSubstPtr[vi];
            }

            if( vtype[vi] == VAR_ORDERED )
                nidx = val <= split.c ? node.left : node.right;
            else
            {
                if( flags & PREPROCESSED_INPUT )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        int a = c = cofs[vi][0];
                        int b = cofs[vi][1];

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                      "one of the input categorical variables is not an integer" );

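                        // binary search for ival in this variable's slice
                        // [cofs[vi][0], cofs[vi][1]) of the sorted cat_map,
                        // mapping the raw category value to its normalized index.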
                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        CV_Assert( c >= 0 && ival == cmap[c] );

                        c -= cofs[vi][0];
                        catbuf[ci] = c;
                    }
                    const int* subset = &subsets[split.subsetOfs];
                    unsigned u = c;
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
                }
            }
        }

        if( predictType == PREDICT_SUM )
            sum += nodes[prev].value;
        else
        {
            lastClassIdx = nodes[prev].classIdx;
            votes[lastClassIdx]++;
        }
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        int best_idx = lastClassIdx;
        if( range.end - range.start > 1 )
        {
            best_idx = 0;
            for( i = 1; i < nclasses; i++ )
                if( votes[best_idx] < votes[i] )
                    best_idx = i;
        }
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
    }

    return (float)sum;
}


float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
{
    CV_Assert( !roots.empty() );
    Mat samples = _samples.getMat(), results;
    int i, nsamples = samples.rows;
    int rtype = CV_32F;
    bool needresults = _results.needed();
    float retval = 0.f;
    bool iscls = isClassifier();
    float scale = !iscls ? 1.f/(int)roots.size() : 1.f;

    if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
        rtype = CV_32S;

    if( needresults )
    {
        _results.create(nsamples, 1, rtype);
        results = _results.getMat();
    }
    else
        nsamples = std::min(nsamples, 1);

    for( i = 0; i < nsamples; i++ )
    {
        float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
        if( needresults )
        {
            if( rtype == CV_32F )
                results.at<float>(i) = val;
            else
                results.at<int>(i) = cvRound(val);
        }
        if( i == 0 )
            retval = val;
    }
    return retval;
}

void DTreesImpl::writeTrainingParams(FileStorage& fs) const
{
    fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
    fs << "max_categories" << params.getMaxCategories();
    fs << "regression_accuracy" << params.getRegressionAccuracy();

    fs << "max_depth" << params.getMaxDepth();
    fs << "min_sample_count" << params.getMinSampleCount();
    fs << "cross_validation_folds" << params.getCVFolds();

    if( params.getCVFolds() > 1 )
        fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);

    if( !params.priors.empty() )
        fs << "priors" << params.priors;
}

void DTreesImpl::writeParams(FileStorage& fs) const
{
    fs << "is_classifier" << isClassifier();
    fs << "var_all" << (int)varType.size();
    fs << "var_count" << getVarCount();

    int ord_var_count = 0, cat_var_count = 0;
    int i, n = (int)varType.size();
    for( i = 0; i < n; i++ )
        if( varType[i] == VAR_ORDERED )
            ord_var_count++;
        else
            cat_var_count++;
    fs << "ord_var_count" << ord_var_count;
    fs << "cat_var_count" << cat_var_count;

    fs << "training_params" << "{";
    writeTrainingParams(fs);

    fs << "}";

    if( !varIdx.empty() )
    {
        fs << "global_var_idx" << 1;
        fs << "var_idx" << varIdx;
    }

    fs << "var_type" << varType;

    if( !catOfs.empty() )
        fs << "cat_ofs" << catOfs;
    if( !catMap.empty() )
        fs << "cat_map" << catMap;
    if( !classLabels.empty() )
        fs << "class_labels" << classLabels;
    if( !missingSubst.empty() )
        fs << "missing_subst" << missingSubst;
}

void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];

    fs << "{:";

    int vi = split.varIdx;
    fs << "var" << vi;
    fs << "quality" << split.quality;

    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        int i, n = getCatCount(vi), to_right = 0;
        const int* subset = &subsets[split.subsetOfs];
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;

        // ad-hoc rule when to use inverse categorical split notation
        // to achieve more compact and clear representation
        int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;

        fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";

        for( i = 0; i < n; i++ )
        {
            int dir = CV_DTREE_CAT_DIR(i, subset);
            if( dir*default_dir < 0 )
                fs << i;
        }

        fs << "]";
    }
    else
        fs << (!split.inversed ? "le" : "gt") << split.c;

    fs << "}";
}

void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
{
    const Node& node = nodes[nidx];
    fs << "{";
    fs << "depth" << depth;
    fs << "value" << node.value;

    if( _isClassifier )
        fs << "norm_class_idx" << node.classIdx;

    if( node.split >= 0 )
    {
        fs << "splits" << "[";

        for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
            writeSplit( fs, splitidx );

        fs << "]";
    }

    fs << "}";
}

void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";

    int nidx = root, pidx = 0, depth = 0;
    const Node *node = 0;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        for(;;)
        {
            writeNode( fs, nidx, depth );
            node = &nodes[nidx];
            if( node->left < 0 )
                break;
            nidx = node->left;
            depth++;
        }

        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
             nidx = pidx, pidx = nodes[pidx].parent )
            depth--;

        if( pidx < 0 )
            break;

        nidx = nodes[pidx].right;
    }

    fs << "]";
}

void DTreesImpl::write( FileStorage& fs ) const
{
    writeParams(fs);
    writeTree(fs, roots[0]);
}

void DTreesImpl::readParams( const FileNode& fn )
{
    _isClassifier = (int)fn["is_classifier"] != 0;
    /*int var_all = (int)fn["var_all"];
    int var_count = (int)fn["var_count"];
    int cat_var_count = (int)fn["cat_var_count"];
    int ord_var_count = (int)fn["ord_var_count"];*/

    FileNode tparams_node = fn["training_params"];

    TreeParams params0 = TreeParams();

    if( !tparams_node.empty() ) // training parameters are optional
    {
        params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
        params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
        params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
        params0.setMaxDepth((int)tparams_node["max_depth"]);
        params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
        params0.setCVFolds((int)tparams_node["cross_validation_folds"]);

        if( params0.getCVFolds() > 1 )
        {
            // store into params0, like the other fields, so that the value
            // survives the final setDParams(params0) call below
            params0.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
        }

        tparams_node["priors"] >> params0.priors;
    }

    readVectorOrMat(fn["var_idx"], varIdx);
    fn["var_type"] >> varType;

    int format = 0;
    fn["format"] >> format;
    bool isLegacy = format < 3;

    int varAll = (int)fn["var_all"];
    if (isLegacy && (int)varType.size() <= varAll)
    {
        std::vector<uchar> extendedTypes(varAll + 1, 0);

        int i = 0, n;
        if (!varIdx.empty())
        {
            n = (int)varIdx.size();
            for (; i < n; ++i)
            {
                int var = varIdx[i];
                extendedTypes[var] = varType[i];
            }
        }
        else
        {
            n = (int)varType.size();
            for (; i < n; ++i)
            {
                extendedTypes[i] = varType[i];
            }
        }
        extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
        extendedTypes.swap(varType);
    }

    readVectorOrMat(fn["cat_map"], catMap);

    if (isLegacy)
    {
        // generating "catOfs" from "cat_count"
        catOfs.clear();
        classLabels.clear();
        std::vector<int> counts;
        readVectorOrMat(fn["cat_count"], counts);
        unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
        for (; i < size; ++i)
        {
            Vec2i newOffsets(0, 0);
            if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
            {
                newOffsets[0] = curShift;
                curShift += counts[j];
                newOffsets[1] = curShift;
                ++j;
            }
            catOfs.push_back(newOffsets);
        }
        // other elements in "catMap" are "classLabels"
        if (curShift < catMap.size())
        {
            classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
            catMap.erase(catMap.begin() + curShift, catMap.end());
        }
    }
    else
    {
        fn["cat_ofs"] >> catOfs;
        fn["missing_subst"] >> missingSubst;
        fn["class_labels"] >> classLabels;
    }

    // init var mapping for node reading (var indexes or varIdx indexes)
    bool globalVarIdx = false;
    fn["global_var_idx"] >> globalVarIdx;
    if (globalVarIdx || varIdx.empty())
        setRangeVector(varMapping, (int)varType.size());
    else
        varMapping = varIdx;

    initCompVarIdx();
    setDParams(params0);
}

int DTreesImpl::readSplit( const FileNode& fn )
{
    Split split;

    int vi = (int)fn["var"];
    CV_Assert( 0 <= vi && vi <= (int)varType.size() );
    vi = varMapping[vi]; // convert to varIdx if needed
    split.varIdx = vi;

    if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
    {
        int i, val, ssize = getSubsetSize(vi);
        split.subsetOfs = (int)subsets.size();
        for( i = 0; i < ssize; i++ )
            subsets.push_back(0);
        int* subset = &subsets[split.subsetOfs];
        FileNode fns = fn["in"];
        if( fns.empty() )
        {
            fns = fn["not_in"];
            split.inversed = true;
        }

        if( fns.isInt() )
        {
            val = (int)fns;
            subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            FileNodeIterator it = fns.begin();
            int n = (int)fns.size();
            for( i = 0; i < n; i++, ++it )
            {
                val = (int)*it;
                subset[val >> 5] |= 1 << (val & 31);
            }
        }

        // for categorical splits we do not use inversed splits,
        // instead we inverse the variable set in the split
        if( split.inversed )
        {
            for( i = 0; i < ssize; i++ )
                subset[i] ^= -1;
            split.inversed = false;
        }
    }
    else
    {
        FileNode cmpNode = fn["le"];
        if( cmpNode.empty() )
        {
            cmpNode = fn["gt"];
            split.inversed = true;
        }
        split.c = (float)cmpNode;
    }

    split.quality = (float)fn["quality"];
    splits.push_back(split);

    return (int)(splits.size() - 1);
}

int DTreesImpl::readNode( const FileNode& fn )
{
    Node node;
    node.value = (double)fn["value"];

    if( _isClassifier )
        node.classIdx = (int)fn["norm_class_idx"];

    FileNode sfn = fn["splits"];
    if( !sfn.empty() )
    {
        int i, n = (int)sfn.size(), prevsplit = -1;
        FileNodeIterator it = sfn.begin();

        for( i = 0; i < n; i++, ++it )
        {
            int splitidx = readSplit(*it);
            if( splitidx < 0 )
                break;
            if( prevsplit < 0 )
                node.split = splitidx;
            else
                splits[prevsplit].next = splitidx;
            prevsplit = splitidx;
        }
    }
    nodes.push_back(node);
    return (int)(nodes.size() - 1);
}

int DTreesImpl::readTree( const FileNode& fn )
{
    int i, n = (int)fn.size(), root = -1, pidx = -1;
    FileNodeIterator it = fn.begin();

    for( i = 0; i < n; i++, ++it )
    {
        int nidx = readNode(*it);
        if( nidx < 0 )
            break;
        Node& node = nodes[nidx];
        node.parent = pidx;
        if( pidx < 0 )
            root = nidx;
        else
        {
            Node& parent = nodes[pidx];
            if( parent.left < 0 )
                parent.left = nidx;
            else
                parent.right = nidx;
        }
        if( node.split >= 0 )
            pidx = nidx;
        else
        {
            while( pidx >= 0 && nodes[pidx].right >= 0 )
                pidx = nodes[pidx].parent;
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::read( const FileNode& fn )
{
    clear();
    readParams(fn);

    FileNode fnodes = fn["nodes"];
    CV_Assert( !fnodes.empty() );
    readTree(fnodes);
}

Ptr<DTrees> DTrees::create()
{
    return makePtr<DTreesImpl>();
}
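
/* A minimal usage sketch (illustration only, not part of the library; assumes
   the usual OpenCV ml headers; `samples`, `responses` and `sampleRow` are
   placeholder Mats supplied by the caller):

       Ptr<ml::DTrees> dtree = ml::DTrees::create();
       dtree->setMaxDepth(8);
       dtree->setMinSampleCount(10);
       dtree->setCVFolds(0);   // disable built-in cross-validation pruning
       Ptr<ml::TrainData> td = ml::TrainData::create(samples, ml::ROW_SAMPLE, responses);
       dtree->train(td);
       float y = dtree->predict(sampleRow);
*/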

}
}

/* End of file. */