/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>

namespace cv {
namespace ml {

using std::vector;

TreeParams::TreeParams()
{
    maxDepth = INT_MAX;
    minSampleCount = 10;
    regressionAccuracy = 0.01f;
    useSurrogates = false;
    maxCategories = 10;
    CVFolds = 10;
    use1SERule = true;
    truncatePrunedTree = true;
    priors = Mat();
}

TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
                       double _regressionAccuracy, bool _useSurrogates,
                       int _maxCategories, int _CVFolds,
                       bool _use1SERule, bool _truncatePrunedTree,
                       const Mat& _priors)
{
    maxDepth = _maxDepth;
    minSampleCount = _minSampleCount;
    regressionAccuracy = (float)_regressionAccuracy;
    useSurrogates = _useSurrogates;
    maxCategories = _maxCategories;
    CVFolds = _CVFolds;
    use1SERule = _use1SERule;
    truncatePrunedTree = _truncatePrunedTree;
    priors = _priors;
}

DTrees::Node::Node()
{
    classIdx = 0;
    value = 0;
    parent = left = right = split = defaultDir = -1;
}

DTrees::Split::Split()
{
    varIdx = 0;
    inversed = false;
    quality = 0.f;
    next = -1;
    c = 0.f;
    subsetOfs = 0;
}


DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
    data = _data;
    vector<int> subsampleIdx;
    Mat sidx0 = _data->getTrainSampleIdx();
    if( !sidx0.empty() )
    {
        sidx0.copyTo(sidx);
        std::sort(sidx.begin(), sidx.end());
    }
    else
    {
        int n = _data->getNSamples();
        setRangeVector(sidx, n);
    }

    maxSubsetSize = 0;
}

DTreesImpl::DTreesImpl() {}
DTreesImpl::~DTreesImpl() {}
void DTreesImpl::clear()
{
    varIdx.clear();
    compVarIdx.clear();
    varType.clear();
    catOfs.clear();
    catMap.clear();
    roots.clear();
    nodes.clear();
    splits.clear();
    subsets.clear();
    classLabels.clear();

    w.release();
    _isClassifier = false;
}

void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
    clear();
    w = makePtr<WorkData>(data);

    Mat vtype = data->getVarType();
    vtype.copyTo(varType);

    data->getCatOfs().copyTo(catOfs);
    data->getCatMap().copyTo(catMap);
    data->getDefaultSubstValues().copyTo(missingSubst);

    int nallvars = data->getNAllVars();

    Mat vidx0 = data->getVarIdx();
    if( !vidx0.empty() )
        vidx0.copyTo(varIdx);
    else
        setRangeVector(varIdx, nallvars);

    initCompVarIdx();

    w->maxSubsetSize = 0;

    int i, nvars = (int)varIdx.size();
    for( i = 0; i < nvars; i++ )
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));

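    // convert the largest category count into the number of 32-bit words
    // needed to store one category subset as a bitmask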
    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);

    data->getSampleWeights().copyTo(w->sample_weights);

    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;

    if( _isClassifier )
    {
        data->getNormCatResponses().copyTo(w->cat_responses);
        data->getClassLabels().copyTo(classLabels);
        int nclasses = (int)classLabels.size();

        Mat class_weights = params.priors;
        if( !class_weights.empty() )
        {
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
            {
                Mat temp;
                class_weights.convertTo(temp, CV_64F);
                class_weights = temp;
            }
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );

            int nsamples = (int)w->cat_responses.size();
            const double* cw = class_weights.ptr<double>();
            CV_Assert( (int)w->sample_weights.size() == nsamples );

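            // fold the per-class priors into the per-sample weights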
            for( i = 0; i < nsamples; i++ )
            {
                int ci = w->cat_responses[i];
                CV_Assert( 0 <= ci && ci < nclasses );
                w->sample_weights[i] *= cw[ci];
            }
        }
    }
    else
        data->getResponses().copyTo(w->ord_responses);
}


void DTreesImpl::initCompVarIdx()
{
    int nallvars = (int)varType.size();
    compVarIdx.assign(nallvars, -1);
    int i, nvars = (int)varIdx.size(), prevIdx = -1;
    for( i = 0; i < nvars; i++ )
    {
        int vi = varIdx[i];
        CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
        prevIdx = vi;
        compVarIdx[vi] = i;
    }
}

void DTreesImpl::endTraining()
{
    w.release();
}

bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
{
    startTraining(trainData, flags);
    bool ok = addTree( w->sidx ) >= 0;
    w.release();
    endTraining();
    return ok;
}

const vector<int>& DTreesImpl::getActiveVars()
{
    return varIdx;
}

int DTreesImpl::addTree(const vector<int>& sidx )
{
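    // rough upper bound on the number of nodes: a tree of depth d has at most
    // 2^d leaves (1024 is used as a fallback when no positive depth limit is
    // set); only used to pre-reserve the working buffers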
    size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size();

    w->wnodes.reserve(n);
    w->wsplits.reserve(n);
    w->wsubsets.reserve(n*w->maxSubsetSize);
    w->wnodes.clear();
    w->wsplits.clear();
    w->wsubsets.clear();

    int cv_n = params.getCVFolds();

    if( cv_n > 0 )
    {
        w->cv_Tn.resize(n*cv_n);
        w->cv_node_error.resize(n*cv_n);
        w->cv_node_risk.resize(n*cv_n);
    }

    // build the tree recursively
    int w_root = addNodeAndTrySplit(-1, sidx);
    int maxdepth = INT_MAX;//pruneCV(root);

    int w_nidx = w_root, pidx = -1, depth = 0;
    int root = (int)nodes.size();

    for(;;)
    {
        const WNode& wnode = w->wnodes[w_nidx];
        Node node;
        node.parent = pidx;
        node.classIdx = wnode.class_idx;
        node.value = wnode.value;
        node.defaultDir = wnode.defaultDir;

        int wsplit_idx = wnode.split;
        if( wsplit_idx >= 0 )
        {
            const WSplit& wsplit = w->wsplits[wsplit_idx];
            Split split;
            split.c = wsplit.c;
            split.quality = wsplit.quality;
            split.inversed = wsplit.inversed;
            split.varIdx = wsplit.varIdx;
            split.subsetOfs = -1;
            if( wsplit.subsetOfs >= 0 )
            {
                int ssize = getSubsetSize(split.varIdx);
                split.subsetOfs = (int)subsets.size();
                subsets.resize(split.subsetOfs + ssize);
                // When ssize == 0 the resize above is a no-op, so indexing
                // subsets[split.subsetOfs] would be out of range; the guard
                // keeps the access safe and also skips a useless memcpy call
                // when the size parameter is zero.
                if(ssize > 0)
                {
                    memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
                }
            }
            node.split = (int)splits.size();
            splits.push_back(split);
        }
        int nidx = (int)nodes.size();
        nodes.push_back(node);
        if( pidx >= 0 )
        {
            int w_pidx = w->wnodes[w_nidx].parent;
            if( w->wnodes[w_pidx].left == w_nidx )
            {
                nodes[pidx].left = nidx;
            }
            else
            {
                CV_Assert(w->wnodes[w_pidx].right == w_nidx);
                nodes[pidx].right = nidx;
            }
        }

        if( wnode.left >= 0 && depth+1 < maxdepth )
        {
            w_nidx = wnode.left;
            pidx = nidx;
            depth++;
        }
        else
        {
            int w_pidx = wnode.parent;
            while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
            {
                w_nidx = w_pidx;
                w_pidx = w->wnodes[w_pidx].parent;
                nidx = pidx;
                pidx = nodes[pidx].parent;
                depth--;
            }

            if( w_pidx < 0 )
                break;

            w_nidx = w->wnodes[w_pidx].right;
            CV_Assert( w_nidx >= 0 );
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::setDParams(const TreeParams& _params)
{
    params = _params;
}

int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
{
    w->wnodes.push_back(WNode());
    int nidx = (int)(w->wnodes.size() - 1);
    WNode& node = w->wnodes.back();

    node.parent = parent;
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
    int nfolds = params.getCVFolds();

    if( nfolds > 0 )
    {
        w->cv_Tn.resize((nidx+1)*nfolds);
        w->cv_node_error.resize((nidx+1)*nfolds);
        w->cv_node_risk.resize((nidx+1)*nfolds);
    }

    int i, n = node.sample_count = (int)sidx.size();
    bool can_split = true;
    vector<int> sleft, sright;

    calcValue( nidx, sidx );

    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
        can_split = false;
    else if( _isClassifier )
    {
        const int* responses = &w->cat_responses[0];
        const int* s = &sidx[0];
        int first = responses[s[0]];
        for( i = 1; i < n; i++ )
            if( responses[s[i]] != first )
                break;
        if( i == n )
            can_split = false;
    }
    else
    {
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
            can_split = false;
    }

    if( can_split )
        node.split = findBestSplit( sidx );

    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);

    if( node.split >= 0 )
    {
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
        if( params.useSurrogates )
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");

        int left = addNodeAndTrySplit( nidx, sleft );
        int right = addNodeAndTrySplit( nidx, sright );
        w->wnodes[nidx].left = left;
        w->wnodes[nidx].right = right;
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
    }

    return nidx;
}

int DTreesImpl::findBestSplit( const vector<int>& _sidx )
{
    const vector<int>& activeVars = getActiveVars();
    int splitidx = -1;
    int vi_, nv = (int)activeVars.size();
    AutoBuffer<int> buf(w->maxSubsetSize*2);
    int *subset = buf, *best_subset = subset + w->maxSubsetSize;
    WSplit split, best_split;
    best_split.quality = 0.;

    for( vi_ = 0; vi_ < nv; vi_++ )
    {
        int vi = activeVars[vi_];
        if( varType[vi] == VAR_CATEGORICAL )
        {
            if( _isClassifier )
                split = findSplitCatClass(vi, _sidx, 0, subset);
            else
                split = findSplitCatReg(vi, _sidx, 0, subset);
        }
        else
        {
            if( _isClassifier )
                split = findSplitOrdClass(vi, _sidx, 0);
            else
                split = findSplitOrdReg(vi, _sidx, 0);
        }
        if( split.quality > best_split.quality )
        {
            best_split = split;
            std::swap(subset, best_subset);
        }
    }

    if( best_split.quality > 0 )
    {
        int best_vi = best_split.varIdx;
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
        w->wsubsets.resize(prevsz + ssize);
        for( i = 0; i < ssize; i++ )
            w->wsubsets[prevsz + i] = best_subset[i];
        best_split.subsetOfs = prevsz;
        w->wsplits.push_back(best_split);
        splitidx = (int)(w->wsplits.size()-1);
    }

    return splitidx;
}

void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();

    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));

    if( cv_n > 0 )
    {
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }

    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        double* cls_count = buf;
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            double *cv_sum = buf, *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            for( j = 0; j < cv_n; j++ )
            {
                double s = sum - cv_sum[j], si = sum - s;
                double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }

        node->node_risk = sum2 - (sum/sumw)*sum;
        node->value = sum/sumw;
    }
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    const double epsilon = FLT_EPSILON*2;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)(uchar*)buf;
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }

    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        L += wval; R -= wval;
        int idx = responses[si];
        double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        if( values[curr] + epsilon < values[next] )
        {
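            // split quality is lsum2/L + rsum2/R written over a common
            // denominator; maximizing it is equivalent to minimizing the
            // weighted Gini impurity of the two branches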
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
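// Each input vector here is the per-class weight histogram of one category;
// the distance computation normalizes vectors and cluster centers by their
// L1 norms (v_weights/c_weights), so categories with similar class
// distributions land in the same cluster.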
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf, *c_weights = buf + n;
    bool modified = true;
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        labels[i] = i < k ? i : r.uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}

DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    int _mi = getCatCount(vi), mi = _mi;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = (double*)buf;
    double* rc = lc + m;
    double* _cjk = rc + m*2, *cjk = _cjk;
    double* c_weights = cjk + m*mi;

    int* labels = (int*)(buf + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    int i, j, k, si, idx;
    double L = 0, R = 0;
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        si = _sidx[i];
        j = labels[i];
        k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        if( mi > params.getMaxCategories() )
        {
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;
            cluster_labels = (int*)(cjk + m*mi);
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        dbl_ptr = (double**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        double sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
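            // consecutive Gray codes differ in exactly one bit, so diff is a
            // power of two; converting it to float and reading the IEEE-754
            // exponent, (u.i >> 23) - 127, yields the bit position without a loop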
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        double* crow = cjk + idx*m;
        double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] + t;
                double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] - t;
                double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
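        // mark the categories belonging to the chosen subset in the output
        // bitmask; CV_DTREE_CAT_DIR() later maps membership to a branch direction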
        if( m == 2 )
        {
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const float epsilon = FLT_EPSILON*2;
    const double* weights = &w->sample_weights[0];
    int n = (int)_sidx.size();

    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));

    float* values = (float*)(uchar*)buf;
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];

    int i, si, best_i = -1;
    double L = 0, R = 0;
    double best_val = initQuality, lsum = 0, rsum = 0;

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // find the optimal split
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        si = _sidx[curr];
        double wval = weights[si];
        double t = responses[si]*wval;
        L += wval; R -= wval;
        lsum += t; rsum -= t;

        if( values[curr] + epsilon < values[next] )
        {
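            // split quality is lsum^2/L + rsum^2/R written over a common
            // denominator; maximizing it is equivalent to minimizing the
            // weighted sum of squared errors of the two branches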
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);

    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = (double*)buf + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }

    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert back to unnormalized sums
    // (there should be very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];

        if( ni > FLT_EPSILON )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}

int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
                         vector<int>& _sleft, vector<int>& _sright )
{
    WSplit split = w->wsplits[splitidx];
    int i, si, n = (int)_sidx.size(), vi = split.varIdx;
    _sleft.reserve(n);
    _sright.reserve(n);
    _sleft.clear();
    _sright.clear();

    AutoBuffer<float> buf(n);
    int mi = getCatCount(vi);
    double wleft = 0, wright = 0;
    const double* weights = &w->sample_weights[0];

    if( mi <= 0 ) // split on an ordered variable
    {
        float c = split.c;
        float* values = buf;
        w->data->getValues(vi, _sidx, values);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            if( values[i] <= c )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    else
    {
        const int* subset = &w->wsubsets[split.subsetOfs];
        int* cat_labels = (int*)(float*)buf;
        w->data->getNormCatValues(vi, _sidx, cat_labels);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            unsigned u = cat_labels[i];
            if( CV_DTREE_CAT_DIR(u, subset) < 0 )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
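    // the default direction points at the branch that received more weight;
    // prediction falls back to it when the split variable of a sample is missing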
    return wleft > wright ? -1 : 1;
}

int DTreesImpl::pruneCV( int root )
{
    vector<double> ab;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if need, apply 1SE rule).
    // 3. store the best index and cut the branches.

    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = params.use1SERule != 0 && _isClassifier;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = updateTreeRNC(root, tree_count, -1);
        if( cutTree(root, tree_count, -1, min_alpha) )
            break;

        ab.push_back(min_alpha);
    }

    if( tree_count > 0 )
    {
        ab[0] = 0.;

        for( ti = 1; ti < tree_count-1; ti++ )
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
        ab[tree_count-1] = DBL_MAX*0.5;

        Mat err_jk(cv_n, tree_count, CV_64F);

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tj < tree_count; tj++ )
            {
                double min_alpha = updateTreeRNC(root, tj, j);
                if( cutTree(root, tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab[tk] > min_alpha )
                        break;
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err_jk.at<double>(j, ti);
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    return min_idx;
}

double DTreesImpl::updateTreeRNC( int root, double T, int fold )
{
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
    double min_alpha = DBL_MAX;

    for(;;)
    {
        WNode *node = 0, *parent = 0;

        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
                }
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
        {
            node = &w->wnodes[nidx];
            parent = &w->wnodes[pidx];
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

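            // cost-complexity measure of the subtree rooted at 'parent':
            // alpha = (R(node) - R(subtree))/(#leaves - 1), the risk reduction
            // per extra leaf; subtrees with the smallest alpha are cut first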
            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = std::min( min_alpha, parent->alpha );
        }

        if( pidx < 0 )
            break;

        node = &w->wnodes[nidx];
        parent = &w->wnodes[pidx];
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        nidx = parent->right;
    }

    return min_alpha;
}

bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
{
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
    WNode* node = &w->wnodes[root];
    if( node->left < 0 )
        return true;

    for(;;)
    {
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    w->cv_Tn[nidx*cv_n + fold] = T;
                else
                    node->Tn = T;
                if( nidx == root )
                    return true;
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
            ;

        if( pidx < 0 )
            break;

        nidx = w->wnodes[pidx].right;
    }

    return false;
}

float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
{
    CV_Assert( sample.type() == CV_32F );

    int predictType = flags & PREDICT_MASK;
    int nvars = (int)varIdx.size();
    if( nvars == 0 )
        nvars = (int)varType.size();
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
    int catbufsize = ncats > 0 ? nvars : 0;
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
    int* votes = buf;
    int* catbuf = votes + nclasses;
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
    const uchar* vtype = &varType[0];
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
    const float* psample = sample.ptr<float>();
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
    double sum = 0.;
    int lastClassIdx = -1;
    const float MISSED_VAL = TrainData::missingValue();

    for( i = 0; i < catbufsize; i++ )
        catbuf[i] = -1;

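    // choose the aggregation rule: regression sums the per-tree predictions;
    // classification lets each tree vote for a class, except for the 2-class
    // raw-output case, which also sums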
    if( predictType == PREDICT_AUTO )
    {
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
            PREDICT_SUM : PREDICT_MAX_VOTE;
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        for( i = 0; i < nclasses; i++ )
            votes[i] = 0;
    }

    for( int ridx = range.start; ridx < range.end; ridx++ )
    {
        int nidx = roots[ridx], prev = nidx, c = 0;

        for(;;)
        {
            prev = nidx;
            const Node& node = nodes[nidx];
            if( node.split < 0 )
                break;
            const Split& split = splits[node.split];
            int vi = split.varIdx;
            int ci = cvidx ? cvidx[vi] : vi;
            float val = psample[ci*sstep];
            if( val == MISSED_VAL )
            {
                if( !missingSubstPtr )
                {
                    nidx = node.defaultDir < 0 ? node.left : node.right;
                    continue;
                }
                val = missingSubstPtr[vi];
            }

            if( vtype[vi] == VAR_ORDERED )
                nidx = val <= split.c ? node.left : node.right;
            else
            {
                if( flags & PREPROCESSED_INPUT )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        int a = c = cofs[vi][0];
                        int b = cofs[vi][1];

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                     "one of the input categorical variables is not an integer" );

                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        CV_Assert( c >= 0 && ival == cmap[c] );

                        c -= cofs[vi][0];
                        catbuf[ci] = c;
                    }
                    const int* subset = &subsets[split.subsetOfs];
                    unsigned u = c;
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
                }
            }
        }

        if( predictType == PREDICT_SUM )
            sum += nodes[prev].value;
        else
        {
            lastClassIdx = nodes[prev].classIdx;
            votes[lastClassIdx]++;
        }
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        int best_idx = lastClassIdx;
        if( range.end - range.start > 1 )
        {
            best_idx = 0;
            for( i = 1; i < nclasses; i++ )
                if( votes[best_idx] < votes[i] )
                    best_idx = i;
        }
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
    }

    return (float)sum;
}


float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
{
    CV_Assert( !roots.empty() );
    Mat samples = _samples.getMat(), results;
    int i, nsamples = samples.rows;
    int rtype = CV_32F;
    bool needresults = _results.needed();
    float retval = 0.f;
    bool iscls = isClassifier();
    float scale = !iscls ? 1.f/(int)roots.size() : 1.f;

    if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
        rtype = CV_32S;

    if( needresults )
    {
        _results.create(nsamples, 1, rtype);
        results = _results.getMat();
    }
    else
        nsamples = std::min(nsamples, 1);

    for( i = 0; i < nsamples; i++ )
    {
        float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
        if( needresults )
        {
            if( rtype == CV_32F )
                results.at<float>(i) = val;
            else
                results.at<int>(i) = cvRound(val);
        }
        if( i == 0 )
            retval = val;
    }
    return retval;
}

void DTreesImpl::writeTrainingParams(FileStorage& fs) const
{
    fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
    fs << "max_categories" << params.getMaxCategories();
    fs << "regression_accuracy" << params.getRegressionAccuracy();

    fs << "max_depth" << params.getMaxDepth();
    fs << "min_sample_count" << params.getMinSampleCount();
    fs << "cross_validation_folds" << params.getCVFolds();

    if( params.getCVFolds() > 1 )
        fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);

    if( !params.priors.empty() )
        fs << "priors" << params.priors;
}

void DTreesImpl::writeParams(FileStorage& fs) const
{
    fs << "is_classifier" << isClassifier();
    fs << "var_all" << (int)varType.size();
    fs << "var_count" << getVarCount();

    int ord_var_count = 0, cat_var_count = 0;
    int i, n = (int)varType.size();
    for( i = 0; i < n; i++ )
        if( varType[i] == VAR_ORDERED )
            ord_var_count++;
        else
            cat_var_count++;
    fs << "ord_var_count" << ord_var_count;
    fs << "cat_var_count" << cat_var_count;

    fs << "training_params" << "{";
    writeTrainingParams(fs);

    fs << "}";

    if( !varIdx.empty() )
    {
        fs << "global_var_idx" << 1;
        fs << "var_idx" << varIdx;
    }

    fs << "var_type" << varType;

    if( !catOfs.empty() )
        fs << "cat_ofs" << catOfs;
    if( !catMap.empty() )
        fs << "cat_map" << catMap;
    if( !classLabels.empty() )
        fs << "class_labels" << classLabels;
    if( !missingSubst.empty() )
        fs << "missing_subst" << missingSubst;
}

void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];

    fs << "{:";

    int vi = split.varIdx;
    fs << "var" << vi;
    fs << "quality" << split.quality;

    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        int i, n = getCatCount(vi), to_right = 0;
        const int* subset = &subsets[split.subsetOfs];
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;

        // ad-hoc rule for when to use the inverse categorical split notation
        // to achieve a more compact and readable representation
1607         int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;
1608 
1609         fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";
1610 
1611         for( i = 0; i < n; i++ )
1612         {
1613             int dir = CV_DTREE_CAT_DIR(i, subset);
1614             if( dir*default_dir < 0 )
1615                 fs << i;
1616         }
1617 
1618         fs << "]";
1619     }
1620     else
1621         fs << (!split.inversed ? "le" : "gt") << split.c;
1622 
1623     fs << "}";
1624 }
1625 
writeNode(FileStorage & fs,int nidx,int depth) const1626 void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
1627 {
1628     const Node& node = nodes[nidx];
1629     fs << "{";
1630     fs << "depth" << depth;
1631     fs << "value" << node.value;
1632 
1633     if( _isClassifier )
1634         fs << "norm_class_idx" << node.classIdx;
1635 
1636     if( node.split >= 0 )
1637     {
1638         fs << "splits" << "[";
1639 
1640         for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
1641             writeSplit( fs, splitidx );
1642 
1643         fs << "]";
1644     }
1645 
1646     fs << "}";
1647 }
1648 
void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";

    int nidx = root, pidx = 0, depth = 0;
    const Node *node = 0;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        for(;;)
        {
            writeNode( fs, nidx, depth );
            node = &nodes[nidx];
            if( node->left < 0 )
                break;
            nidx = node->left;
            depth++;
        }

        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
             nidx = pidx, pidx = nodes[pidx].parent )
            depth--;

        if( pidx < 0 )
            break;

        nidx = nodes[pidx].right;
    }

    fs << "]";
}

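// write() saves the parameters followed by the first (and, for a single
// decision tree, only) tree. A minimal round trip through this code might
// look like the sketch below ("tree.yml" is an arbitrary file name):
//
//   FileStorage fsOut("tree.yml", FileStorage::WRITE);
//   model->write(fsOut);
//   fsOut.release();
//
//   FileStorage fsIn("tree.yml", FileStorage::READ);
//   model->read(fsIn.root());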
void DTreesImpl::write( FileStorage& fs ) const
{
    writeParams(fs);
    writeTree(fs, roots[0]);
}

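// readParams() restores the metadata written by writeParams(). It also
// accepts the legacy (pre-format-3) layout, where the variable-type array
// and the categorical offset table have to be re-derived from "var_all"
// and "cat_count".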
void DTreesImpl::readParams( const FileNode& fn )
{
    _isClassifier = (int)fn["is_classifier"] != 0;

    FileNode tparams_node = fn["training_params"];

    TreeParams params0 = TreeParams();

    if( !tparams_node.empty() ) // training parameters are not necessary
    {
        params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
        params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
        params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
        params0.setMaxDepth((int)tparams_node["max_depth"]);
        params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
        params0.setCVFolds((int)tparams_node["cross_validation_folds"]);

        if( params0.getCVFolds() > 1 )
        {
            params0.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
        }

        tparams_node["priors"] >> params0.priors;
    }

    readVectorOrMat(fn["var_idx"], varIdx);
    fn["var_type"] >> varType;

    int format = 0;
    fn["format"] >> format;
    bool isLegacy = format < 3;

    int varAll = (int)fn["var_all"];
    if (isLegacy && (int)varType.size() <= varAll)
    {
        std::vector<uchar> extendedTypes(varAll + 1, 0);

        int i = 0, n;
        if (!varIdx.empty())
        {
            n = (int)varIdx.size();
            for (; i < n; ++i)
            {
                int var = varIdx[i];
                extendedTypes[var] = varType[i];
            }
        }
        else
        {
            n = (int)varType.size();
            for (; i < n; ++i)
            {
                extendedTypes[i] = varType[i];
            }
        }
        extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
        extendedTypes.swap(varType);
    }

    readVectorOrMat(fn["cat_map"], catMap);

    if (isLegacy)
    {
        // generating "catOfs" from "cat_count"
        catOfs.clear();
        classLabels.clear();
        std::vector<int> counts;
        readVectorOrMat(fn["cat_count"], counts);
        unsigned int i = 0, j = 0, curShift = 0, size = (unsigned)varType.size() - 1;
        for (; i < size; ++i)
        {
            Vec2i newOffsets(0, 0);
            if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
            {
                newOffsets[0] = curShift;
                curShift += counts[j];
                newOffsets[1] = curShift;
                ++j;
            }
            catOfs.push_back(newOffsets);
        }
        // other elements in "catMap" are "classLabels"
        if (curShift < catMap.size())
        {
            classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
            catMap.erase(catMap.begin() + curShift, catMap.end());
        }
    }
    else
    {
        fn["cat_ofs"] >> catOfs;
        fn["missing_subst"] >> missingSubst;
        fn["class_labels"] >> classLabels;
    }

    // init var mapping for node reading (var indexes or varIdx indexes)
    bool globalVarIdx = false;
    fn["global_var_idx"] >> globalVarIdx;
    if (globalVarIdx || varIdx.empty())
        setRangeVector(varMapping, (int)varType.size());
    else
        varMapping = varIdx;

    initCompVarIdx();
    setDParams(params0);
}

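// readSplit() is the inverse of writeSplit(). For categorical splits the
// listed category indexes are unpacked into a bitset stored in `subsets`:
// category `val` maps to word `val >> 5` (val/32) and bit `val & 31`
// (val%32). A "not_in" list is folded in by inverting every word
// (subset[i] ^= -1) so that evaluation never needs an inversed flag.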
int DTreesImpl::readSplit( const FileNode& fn )
{
    Split split;

    int vi = (int)fn["var"];
    CV_Assert( 0 <= vi && vi < (int)varMapping.size() );
    vi = varMapping[vi]; // convert to varIdx if needed
    split.varIdx = vi;

    if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
    {
        int i, val, ssize = getSubsetSize(vi);
        split.subsetOfs = (int)subsets.size();
        for( i = 0; i < ssize; i++ )
            subsets.push_back(0);
        int* subset = &subsets[split.subsetOfs];
        FileNode fns = fn["in"];
        if( fns.empty() )
        {
            fns = fn["not_in"];
            split.inversed = true;
        }

        if( fns.isInt() )
        {
            val = (int)fns;
            subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            FileNodeIterator it = fns.begin();
            int n = (int)fns.size();
            for( i = 0; i < n; i++, ++it )
            {
                val = (int)*it;
                subset[val >> 5] |= 1 << (val & 31);
            }
        }

        // for categorical splits we do not use inversed splits,
        // instead we inverse the variable set in the split
        if( split.inversed )
        {
            for( i = 0; i < ssize; i++ )
                subset[i] ^= -1;
            split.inversed = false;
        }
    }
    else
    {
        FileNode cmpNode = fn["le"];
        if( cmpNode.empty() )
        {
            cmpNode = fn["gt"];
            split.inversed = true;
        }
        split.c = (float)cmpNode;
    }

    split.quality = (float)fn["quality"];
    splits.push_back(split);

    return (int)(splits.size() - 1);
}

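// readNode() restores one node and re-links its split chain: the first
// split read becomes Node::split and each further (surrogate) split is
// appended through Split::next, mirroring what writeNode() emitted.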
int DTreesImpl::readNode( const FileNode& fn )
{
    Node node;
    node.value = (double)fn["value"];

    if( _isClassifier )
        node.classIdx = (int)fn["norm_class_idx"];

    FileNode sfn = fn["splits"];
    if( !sfn.empty() )
    {
        int i, n = (int)sfn.size(), prevsplit = -1;
        FileNodeIterator it = sfn.begin();

        for( i = 0; i < n; i++, ++it )
        {
            int splitidx = readSplit(*it);
            if( splitidx < 0 )
                break;
            if( prevsplit < 0 )
                node.split = splitidx;
            else
                splits[prevsplit].next = splitidx;
            prevsplit = splitidx;
        }
    }
    nodes.push_back(node);
    return (int)(nodes.size() - 1);
}

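// readTree() rebuilds the tree structure from the flat depth-first node
// list: a node with a split becomes the parent of the nodes that follow,
// a leaf closes the current branch, and the walk climbs back to the
// nearest ancestor that is still missing a right child.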
int DTreesImpl::readTree( const FileNode& fn )
{
    int i, n = (int)fn.size(), root = -1, pidx = -1;
    FileNodeIterator it = fn.begin();

    for( i = 0; i < n; i++, ++it )
    {
        int nidx = readNode(*it);
        if( nidx < 0 )
            break;
        Node& node = nodes[nidx];
        node.parent = pidx;
        if( pidx < 0 )
            root = nidx;
        else
        {
            Node& parent = nodes[pidx];
            if( parent.left < 0 )
                parent.left = nidx;
            else
                parent.right = nidx;
        }
        if( node.split >= 0 )
            pidx = nidx;
        else
        {
            while( pidx >= 0 && nodes[pidx].right >= 0 )
                pidx = nodes[pidx].parent;
        }
    }
    roots.push_back(root);
    return root;
}

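// read() resets the model and deserializes the parameters plus a single
// tree; ensemble implementations built on DTreesImpl (such as random
// trees or boosting) override this to read several trees.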
void DTreesImpl::read( const FileNode& fn )
{
    clear();
    readParams(fn);

    FileNode fnodes = fn["nodes"];
    CV_Assert( !fnodes.empty() );
    readTree(fnodes);
}

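// DTrees::create() is the public factory for a single decision tree.
// A minimal training sketch, assuming `samples` (CV_32F, one row per
// sample) and `responses` are already filled:
//
//   Ptr<DTrees> dtree = DTrees::create();
//   dtree->setMaxDepth(8);
//   dtree->setCVFolds(0); // skip built-in cross-validation pruning
//   dtree->train(TrainData::create(samples, ROW_SAMPLE, responses));
//   float prediction = dtree->predict(testSample);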
Ptr<DTrees> DTrees::create()
{
    return makePtr<DTreesImpl>();
}

} // namespace ml
} // namespace cv

/* End of file. */