1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
20 //
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
24 //
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #ifndef __OPENCV_ML_HPP__
42 #define __OPENCV_ML_HPP__
43
44 #ifdef __cplusplus
45 # include "opencv2/core.hpp"
46 #endif
47
48 #include "opencv2/core/core_c.h"
49 #include <limits.h>
50
51 #ifdef __cplusplus
52
53 #include <map>
54 #include <iostream>
55
// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header
58 #undef check
59
60 /****************************************************************************************\
61 * Main struct definitions *
62 \****************************************************************************************/
63
/* log(2*PI) */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* columns of <trainData> matrix are training samples */
#define CV_COL_SAMPLE 0

/* rows of <trainData> matrix are training samples */
#define CV_ROW_SAMPLE 1

/* non-zero when the CV_ROW_SAMPLE bit is set in <flags>, zero otherwise */
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
74
75 struct CvVectors
76 {
77 int type;
78 int dims, count;
79 CvVectors* next;
80 union
81 {
82 uchar** ptr;
83 float** fl;
84 double** db;
85 } data;
86 };
87
#if 0
/* A structure, representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1.
   (This whole section is compiled out by the surrounding #if 0.) */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* Builds a lattice from two bounds given in any order; <log_step> is raised to 1
   when a smaller value is passed. */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

/* Returns a lattice with all three fields set to zero. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif
116
/* Variable type (NUMERICAL and ORDERED are two names for the same value) */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

/* Type-name strings, one per model kind */
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

/* Values accepted by the <type> argument of the calc_error() methods */
#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1
135
136 class CvStatModel
137 {
138 public:
139 CvStatModel();
140 virtual ~CvStatModel();
141
142 virtual void clear();
143
144 CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
145 CV_WRAP virtual void load( const char* filename, const char* name=0 );
146
147 virtual void write( CvFileStorage* storage, const char* name ) const;
148 virtual void read( CvFileStorage* storage, CvFileNode* node );
149
150 protected:
151 const char* default_model_name;
152 };
153
154 /****************************************************************************************\
155 * Normal Bayes Classifier *
156 \****************************************************************************************/
157
/* The structure, representing the grid range of statmodel parameters.
   It is used for optimizing statmodel accuracy by varying model parameters,
   the accuracy estimate being computed by cross-validation.
   The grid is logarithmic, so <step> must be greater than 1. */
162
163 class CvMLData;
164
165 struct CvParamGrid
166 {
167 // SVM params type
168 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
169
CvParamGridCvParamGrid170 CvParamGrid()
171 {
172 min_val = max_val = step = 0;
173 }
174
175 CvParamGrid( double min_val, double max_val, double log_step );
176 //CvParamGrid( int param_id );
177 bool check() const;
178
179 CV_PROP_RW double min_val;
180 CV_PROP_RW double max_val;
181 CV_PROP_RW double step;
182 };
183
CvParamGrid(double _min_val,double _max_val,double _log_step)184 inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
185 {
186 min_val = _min_val;
187 max_val = _max_val;
188 step = _log_step;
189 }
190
191 class CvNormalBayesClassifier : public CvStatModel
192 {
193 public:
194 CV_WRAP CvNormalBayesClassifier();
195 virtual ~CvNormalBayesClassifier();
196
197 CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
198 const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
199
200 virtual bool train( const CvMat* trainData, const CvMat* responses,
201 const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
202
203 virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const;
204 CV_WRAP virtual void clear();
205
206 CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
207 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
208 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
209 const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
210 bool update=false );
211 CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const;
212
213 virtual void write( CvFileStorage* storage, const char* name ) const;
214 virtual void read( CvFileStorage* storage, CvFileNode* node );
215
216 protected:
217 int var_count, var_all;
218 CvMat* var_idx;
219 CvMat* cls_labels;
220 CvMat** count;
221 CvMat** sum;
222 CvMat** productsum;
223 CvMat** avg;
224 CvMat** inv_eigen_values;
225 CvMat** cov_rotate_mats;
226 CvMat* c;
227 };
228
229
230 /****************************************************************************************\
231 * K-Nearest Neighbour Classifier *
232 \****************************************************************************************/
233
234 // k Nearest Neighbors
235 class CvKNearest : public CvStatModel
236 {
237 public:
238
239 CV_WRAP CvKNearest();
240 virtual ~CvKNearest();
241
242 CvKNearest( const CvMat* trainData, const CvMat* responses,
243 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
244
245 virtual bool train( const CvMat* trainData, const CvMat* responses,
246 const CvMat* sampleIdx=0, bool is_regression=false,
247 int maxK=32, bool updateBase=false );
248
249 virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
250 const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
251
252 CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
253 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
254
255 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
256 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
257 int maxK=32, bool updateBase=false );
258
259 virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
260 const float** neighbors=0, cv::Mat* neighborResponses=0,
261 cv::Mat* dist=0 ) const;
262 CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
263 CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
264
265 virtual void clear();
266 int get_max_k() const;
267 int get_var_count() const;
268 int get_sample_count() const;
269 bool is_regression() const;
270
271 virtual float write_results( int k, int k1, int start, int end,
272 const float* neighbor_responses, const float* dist, CvMat* _results,
273 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
274
275 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
276 float* neighbor_responses, const float** neighbors, float* dist ) const;
277
278 protected:
279
280 int max_k, var_count;
281 int total;
282 bool regression;
283 CvVectors* samples;
284 };
285
286 /****************************************************************************************\
287 * Support Vector Machines *
288 \****************************************************************************************/
289
290 // SVM training parameters
291 struct CvSVMParams
292 {
293 CvSVMParams();
294 CvSVMParams( int svm_type, int kernel_type,
295 double degree, double gamma, double coef0,
296 double Cvalue, double nu, double p,
297 CvMat* class_weights, CvTermCriteria term_crit );
298
299 CV_PROP_RW int svm_type;
300 CV_PROP_RW int kernel_type;
301 CV_PROP_RW double degree; // for poly
302 CV_PROP_RW double gamma; // for poly/rbf/sigmoid/chi2
303 CV_PROP_RW double coef0; // for poly/sigmoid
304
305 CV_PROP_RW double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
306 CV_PROP_RW double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
307 CV_PROP_RW double p; // for CV_SVM_EPS_SVR
308 CvMat* class_weights; // for CV_SVM_C_SVC
309 CV_PROP_RW CvTermCriteria term_crit; // termination criteria
310 };
311
312
313 struct CvSVMKernel
314 {
315 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
316 const float* another, float* results );
317 CvSVMKernel();
318 CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
319 virtual bool create( const CvSVMParams* params, Calc _calc_func );
320 virtual ~CvSVMKernel();
321
322 virtual void clear();
323 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
324
325 const CvSVMParams* params;
326 Calc calc_func;
327
328 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
329 const float* another, float* results,
330 double alpha, double beta );
331 virtual void calc_intersec( int vcount, int var_count, const float** vecs,
332 const float* another, float* results );
333 virtual void calc_chi2( int vec_count, int vec_size, const float** vecs,
334 const float* another, float* results );
335 virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
336 const float* another, float* results );
337 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
338 const float* another, float* results );
339 virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
340 const float* another, float* results );
341 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
342 const float* another, float* results );
343 };
344
345
/* Node of a doubly linked list (prev/next) carrying a pointer to float data.
   CvSVMSolver keeps one instance as <lru_list> and an array of them as <rows>. */
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;
    CvSVMKernelRow* next;
    float* data;
};
352
353
/* Plain bundle of scalar results that the CvSVMSolver::solve_* methods fill in
   through their CvSVMSolutionInfo& argument. */
struct CvSVMSolutionInfo
{
    double obj;
    double rho;
    double upper_bound_p;
    double upper_bound_n;
    double r;   // for Solver_NU
};
362
363 class CvSVMSolver
364 {
365 public:
366 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
367 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
368 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
369
370 CvSVMSolver();
371
372 CvSVMSolver( int count, int var_count, const float** samples, schar* y,
373 int alpha_count, double* alpha, double Cp, double Cn,
374 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
375 SelectWorkingSet select_working_set, CalcRho calc_rho );
376 virtual bool create( int count, int var_count, const float** samples, schar* y,
377 int alpha_count, double* alpha, double Cp, double Cn,
378 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
379 SelectWorkingSet select_working_set, CalcRho calc_rho );
380 virtual ~CvSVMSolver();
381
382 virtual void clear();
383 virtual bool solve_generic( CvSVMSolutionInfo& si );
384
385 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
386 double Cp, double Cn, CvMemStorage* storage,
387 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
388 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
389 CvMemStorage* storage, CvSVMKernel* kernel,
390 double* alpha, CvSVMSolutionInfo& si );
391 virtual bool solve_one_class( int count, int var_count, const float** samples,
392 CvMemStorage* storage, CvSVMKernel* kernel,
393 double* alpha, CvSVMSolutionInfo& si );
394
395 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
396 CvMemStorage* storage, CvSVMKernel* kernel,
397 double* alpha, CvSVMSolutionInfo& si );
398
399 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
400 CvMemStorage* storage, CvSVMKernel* kernel,
401 double* alpha, CvSVMSolutionInfo& si );
402
403 virtual float* get_row_base( int i, bool* _existed );
404 virtual float* get_row( int i, float* dst );
405
406 int sample_count;
407 int var_count;
408 int cache_size;
409 int cache_line_size;
410 const float** samples;
411 const CvSVMParams* params;
412 CvMemStorage* storage;
413 CvSVMKernelRow lru_list;
414 CvSVMKernelRow* rows;
415
416 int alpha_count;
417
418 double* G;
419 double* alpha;
420
421 // -1 - lower bound, 0 - free, 1 - upper bound
422 schar* alpha_status;
423
424 schar* y;
425 double* b;
426 float* buf[2];
427 double eps;
428 int max_iter;
429 double C[2]; // C[0] == Cn, C[1] == Cp
430 CvSVMKernel* kernel;
431
432 SelectWorkingSet select_working_set_func;
433 CalcRho calc_rho_func;
434 GetRow get_row_func;
435
436 virtual bool select_working_set( int& i, int& j );
437 virtual bool select_working_set_nu_svm( int& i, int& j );
438 virtual void calc_rho( double& rho, double& r );
439 virtual void calc_rho_nu_svm( double& rho, double& r );
440
441 virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
442 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
443 virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
444 };
445
446
/* One decision function of a trained SVM, as exposed by
   CvSVM::get_decision_function().
   NOTE(review): <alpha> and <sv_index> presumably each hold <sv_count>
   entries - confirm in the implementation. */
struct CvSVMDecisionFunc
{
    double rho;
    int sv_count;
    double* alpha;
    int* sv_index;
};
454
455
456 // SVM model
457 class CvSVM : public CvStatModel
458 {
459 public:
460 // SVM type
461 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
462
463 // SVM kernel type
464 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3, CHI2=4, INTER=5 };
465
466 // SVM params type
467 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
468
469 CV_WRAP CvSVM();
470 virtual ~CvSVM();
471
472 CvSVM( const CvMat* trainData, const CvMat* responses,
473 const CvMat* varIdx=0, const CvMat* sampleIdx=0,
474 CvSVMParams params=CvSVMParams() );
475
476 virtual bool train( const CvMat* trainData, const CvMat* responses,
477 const CvMat* varIdx=0, const CvMat* sampleIdx=0,
478 CvSVMParams params=CvSVMParams() );
479
480 virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
481 const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
482 int kfold = 10,
483 CvParamGrid Cgrid = get_default_grid(CvSVM::C),
484 CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA),
485 CvParamGrid pGrid = get_default_grid(CvSVM::P),
486 CvParamGrid nuGrid = get_default_grid(CvSVM::NU),
487 CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF),
488 CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
489 bool balanced=false );
490
491 virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
492 virtual float predict( const CvMat* samples, CV_OUT CvMat* results, bool returnDFVal=false ) const;
493
494 CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
495 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
496 CvSVMParams params=CvSVMParams() );
497
498 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
499 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
500 CvSVMParams params=CvSVMParams() );
501
502 CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
503 const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
504 int k_fold = 10,
505 CvParamGrid Cgrid = CvSVM::get_default_grid(CvSVM::C),
506 CvParamGrid gammaGrid = CvSVM::get_default_grid(CvSVM::GAMMA),
507 CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P),
508 CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU),
509 CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF),
510 CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
511 bool balanced=false);
512 CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
513 CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const;
514
515 CV_WRAP virtual int get_support_vector_count() const;
516 virtual const float* get_support_vector(int i) const;
get_params() const517 virtual CvSVMParams get_params() const { return params; }
518 CV_WRAP virtual void clear();
519
get_decision_function() const520 virtual const CvSVMDecisionFunc* get_decision_function() const { return decision_func; }
521
522 static CvParamGrid get_default_grid( int param_id );
523
524 virtual void write( CvFileStorage* storage, const char* name ) const;
525 virtual void read( CvFileStorage* storage, CvFileNode* node );
get_var_count() const526 CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
527
528 protected:
529
530 virtual bool set_params( const CvSVMParams& params );
531 virtual bool train1( int sample_count, int var_count, const float** samples,
532 const void* responses, double Cp, double Cn,
533 CvMemStorage* _storage, double* alpha, double& rho );
534 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
535 const CvMat* responses, CvMemStorage* _storage, double* alpha );
536 virtual void create_kernel();
537 virtual void create_solver();
538
539 virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
540
541 virtual void write_params( CvFileStorage* fs ) const;
542 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
543
544 void optimize_linear_svm();
545
546 CvSVMParams params;
547 CvMat* class_labels;
548 int var_all;
549 float** sv;
550 int sv_total;
551 CvMat* var_idx;
552 CvMat* class_weights;
553 CvSVMDecisionFunc* decision_func;
554 CvMemStorage* storage;
555
556 CvSVMSolver* solver;
557 CvSVMKernel* kernel;
558
559 private:
560 CvSVM(const CvSVM&);
561 CvSVM& operator = (const CvSVM&);
562 };
563
/****************************************************************************************\
*                                      Decision Tree                                     *
\****************************************************************************************/
/* Pair of pointers: one to unsigned 16-bit data, one to int data. */
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};
572
573
/* Tests bit <idx> of the bit set <subset> (32 bits per int element):
   evaluates to -1 when the bit is set and to +1 when it is clear.
   <subset> is parenthesised so that pointer expressions can be passed, and the
   mask is built from an unsigned 1 so that bit 31 does not shift into the sign bit. */
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((((subset)[(idx)>>5]) & (1u << ((idx) & 31)))==0)-1)
576
/* One split of a decision-tree node; splits are chained through <next>.
   The anonymous union holds either <subset> (the array form used with
   CV_DTREE_CAT_DIR-style bit tests) or <ord> (a float value plus a split point).
   NOTE(review): which union member is active for which variable type is not
   visible here - confirm against the code that builds splits. */
struct CvDTreeSplit
{
    int var_idx;
    int condensed_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;
    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        }
        ord;
    };
};
595
596 struct CvDTreeNode
597 {
598 int class_idx;
599 int Tn;
600 double value;
601
602 CvDTreeNode* parent;
603 CvDTreeNode* left;
604 CvDTreeNode* right;
605
606 CvDTreeSplit* split;
607
608 int sample_count;
609 int depth;
610 int* num_valid;
611 int offset;
612 int buf_idx;
613 double maxlr;
614
615 // global pruning data
616 int complexity;
617 double alpha;
618 double node_risk, tree_risk, tree_error;
619
620 // cross-validation pruning data
621 int* cv_Tn;
622 double* cv_node_risk;
623 double* cv_node_error;
624
get_num_validCvDTreeNode625 int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
set_num_validCvDTreeNode626 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
627 };
628
629
630 struct CvDTreeParams
631 {
632 CV_PROP_RW int max_categories;
633 CV_PROP_RW int max_depth;
634 CV_PROP_RW int min_sample_count;
635 CV_PROP_RW int cv_folds;
636 CV_PROP_RW bool use_surrogates;
637 CV_PROP_RW bool use_1se_rule;
638 CV_PROP_RW bool truncate_pruned_tree;
639 CV_PROP_RW float regression_accuracy;
640 const float* priors;
641
642 CvDTreeParams();
643 CvDTreeParams( int max_depth, int min_sample_count,
644 float regression_accuracy, bool use_surrogates,
645 int max_categories, int cv_folds,
646 bool use_1se_rule, bool truncate_pruned_tree,
647 const float* priors );
648 };
649
650
651 struct CvDTreeTrainData
652 {
653 CvDTreeTrainData();
654 CvDTreeTrainData( const CvMat* trainData, int tflag,
655 const CvMat* responses, const CvMat* varIdx=0,
656 const CvMat* sampleIdx=0, const CvMat* varType=0,
657 const CvMat* missingDataMask=0,
658 const CvDTreeParams& params=CvDTreeParams(),
659 bool _shared=false, bool _add_labels=false );
660 virtual ~CvDTreeTrainData();
661
662 virtual void set_data( const CvMat* trainData, int tflag,
663 const CvMat* responses, const CvMat* varIdx=0,
664 const CvMat* sampleIdx=0, const CvMat* varType=0,
665 const CvMat* missingDataMask=0,
666 const CvDTreeParams& params=CvDTreeParams(),
667 bool _shared=false, bool _add_labels=false,
668 bool _update_data=false );
669 virtual void do_responses_copy();
670
671 virtual void get_vectors( const CvMat* _subsample_idx,
672 float* values, uchar* missing, float* responses, bool get_class_idx=false );
673
674 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
675
676 virtual void write_params( CvFileStorage* fs ) const;
677 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
678
679 // release all the data
680 virtual void clear();
681
682 int get_num_classes() const;
683 int get_var_type(int vi) const;
get_work_var_countCvDTreeTrainData684 int get_work_var_count() const {return work_var_count;}
685
686 virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
687 virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
688 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
689 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
690 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
691 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
692 const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
693 virtual int get_child_buf_idx( CvDTreeNode* n );
694
695 ////////////////////////////////////
696
697 virtual bool set_params( const CvDTreeParams& params );
698 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
699 int storage_idx, int offset );
700
701 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
702 int split_point, int inversed, float quality );
703 virtual CvDTreeSplit* new_split_cat( int vi, float quality );
704 virtual void free_node_data( CvDTreeNode* node );
705 virtual void free_train_data();
706 virtual void free_node( CvDTreeNode* node );
707
708 int sample_count, var_all, var_count, max_c_count;
709 int ord_var_count, cat_var_count, work_var_count;
710 bool have_labels, have_priors;
711 bool is_classifier;
712 int tflag;
713
714 const CvMat* train_data;
715 const CvMat* responses;
716 CvMat* responses_copy; // used in Boosting
717
718 int buf_count, buf_size; // buf_size is obsolete, please do not use it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead
719 bool shared;
720 int is_buf_16u;
721
722 CvMat* cat_count;
723 CvMat* cat_ofs;
724 CvMat* cat_map;
725
726 CvMat* counts;
727 CvMat* buf;
get_length_subbufCvDTreeTrainData728 inline size_t get_length_subbuf() const
729 {
730 size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count;
731 return res;
732 }
733
734 CvMat* direction;
735 CvMat* split_buf;
736
737 CvMat* var_idx;
738 CvMat* var_type; // i-th element =
739 // k<0 - ordered
740 // k>=0 - categorical, see k-th element of cat_* arrays
741 CvMat* priors;
742 CvMat* priors_mult;
743
744 CvDTreeParams params;
745
746 CvMemStorage* tree_storage;
747 CvMemStorage* temp_storage;
748
749 CvDTreeNode* data_root;
750
751 CvSet* node_heap;
752 CvSet* split_heap;
753 CvSet* cv_heap;
754 CvSet* nv_heap;
755
756 cv::RNG* rng;
757 };
758
// Forward declarations needed by the friend declarations in the tree classes below.
class CvDTree;
class CvForestTree;

namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
767
768 class CvDTree : public CvStatModel
769 {
770 public:
771 CV_WRAP CvDTree();
772 virtual ~CvDTree();
773
774 virtual bool train( const CvMat* trainData, int tflag,
775 const CvMat* responses, const CvMat* varIdx=0,
776 const CvMat* sampleIdx=0, const CvMat* varType=0,
777 const CvMat* missingDataMask=0,
778 CvDTreeParams params=CvDTreeParams() );
779
780 virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
781
782 // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
783 virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
784
785 virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
786
787 virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
788 bool preprocessedInput=false ) const;
789
790 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
791 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
792 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
793 const cv::Mat& missingDataMask=cv::Mat(),
794 CvDTreeParams params=CvDTreeParams() );
795
796 CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
797 bool preprocessedInput=false ) const;
798 CV_WRAP virtual cv::Mat getVarImportance();
799
800 virtual const CvMat* get_var_importance();
801 CV_WRAP virtual void clear();
802
803 virtual void read( CvFileStorage* fs, CvFileNode* node );
804 virtual void write( CvFileStorage* fs, const char* name ) const;
805
806 // special read & write methods for trees in the tree ensembles
807 virtual void read( CvFileStorage* fs, CvFileNode* node,
808 CvDTreeTrainData* data );
809 virtual void write( CvFileStorage* fs ) const;
810
811 const CvDTreeNode* get_root() const;
812 int get_pruned_tree_idx() const;
813 CvDTreeTrainData* get_data();
814
815 protected:
816 friend struct cv::DTreeBestSplitFinder;
817
818 virtual bool do_train( const CvMat* _subsample_idx );
819
820 virtual void try_split_node( CvDTreeNode* n );
821 virtual void split_node_data( CvDTreeNode* n );
822 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
823 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
824 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
825 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
826 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
827 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
828 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
829 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
830 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
831 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
832 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
833 virtual double calc_node_dir( CvDTreeNode* node );
834 virtual void complete_node_dir( CvDTreeNode* node );
835 virtual void cluster_categories( const int* vectors, int vector_count,
836 int var_count, int* sums, int k, int* cluster_labels );
837
838 virtual void calc_node_value( CvDTreeNode* node );
839
840 virtual void prune_cv();
841 virtual double update_tree_rnc( int T, int fold );
842 virtual int cut_tree( int T, int fold, double min_alpha );
843 virtual void free_prune_data(bool cut_tree);
844 virtual void free_tree();
845
846 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
847 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
848 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
849 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
850 virtual void write_tree_nodes( CvFileStorage* fs ) const;
851 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
852
853 CvDTreeNode* root;
854 CvMat* var_importance;
855 CvDTreeTrainData* data;
856 CvMat train_data_hdr, responses_hdr;
857 cv::Mat train_data_mat, responses_mat;
858
859 public:
860 int pruned_tree_idx;
861 };
862
863
864 /****************************************************************************************\
865 * Random Trees Classifier *
866 \****************************************************************************************/
867
868 class CvRTrees;
869
// One tree of a random forest. Extends CvDTree with train()/read() overloads
// that additionally receive a CvRTrees object, and stores a pointer to a forest.
870 class CvForestTree: public CvDTree
871 {
872 public:
873 CvForestTree();
874 virtual ~CvForestTree();
875
    // Forest-aware training overload: receives the training data, a subsample
    // index matrix and a forest.
    // NOTE(review): the contents expected in _subsample_idx are defined by the
    // implementation file -- confirm there.
876 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
877
    // Returns data->var_count, or 0 when no training data is attached (data is null).
878 virtual int get_var_count() const {return data ? data->var_count : 0;}
    // Forest-aware read overload: takes the file storage and node plus a
    // forest and a training-data descriptor.
879 virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
880
    // The overloads between the BEGIN/END markers take no forest argument.
    // Per the markers they are declared only to avoid compiler warnings.
    // NOTE(review): presumably they are not meant to be called on a forest
    // tree -- confirm in the implementation.
881 /* dummy methods to avoid warnings: BEGIN */
882 virtual bool train( const CvMat* trainData, int tflag,
883 const CvMat* responses, const CvMat* varIdx=0,
884 const CvMat* sampleIdx=0, const CvMat* varType=0,
885 const CvMat* missingDataMask=0,
886 CvDTreeParams params=CvDTreeParams() );
887
888 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
889 virtual void read( CvFileStorage* fs, CvFileNode* node );
890 virtual void read( CvFileStorage* fs, CvFileNode* node,
891 CvDTreeTrainData* data );
892 /* dummy methods to avoid warnings: END */
893
894 protected:
    // Gives the split-finder helper struct access to non-public members.
895 friend struct cv::ForestTreeBestSplitFinder;
896
    // Chooses a split for node n.
    // NOTE(review): presumably restricted to the forest's active variables -- confirm.
897 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    // Pointer to a forest.
    // NOTE(review): presumably the forest passed to train()/read() and not owned by the tree -- confirm.
898 CvRTrees* forest;
899 };
900
901
// Parameter set for random trees: CvDTreeParams extended with the three
// forest-level fields declared below.
902 struct CvRTParams : public CvDTreeParams
903 {
904 //Parameters for the forest
905 CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
906 CV_PROP_RW int nactive_vars; // NOTE(review): presumably the number of variables considered at each split -- confirm
907 CV_PROP_RW CvTermCriteria term_crit; // NOTE(review): presumably the stopping rule for growing the forest -- confirm
908
    // Default constructor; the default values are set in the implementation file.
909 CvRTParams();
    // Full constructor. The leading arguments carry tree-level settings, the
    // trailing ones forest-level settings.
    // NOTE(review): max_num_of_trees_in_the_forest, forest_accuracy and
    // termcrit_type presumably populate term_crit -- confirm in the implementation.
910 CvRTParams( int max_depth, int min_sample_count,
911 float regression_accuracy, bool use_surrogates,
912 int max_categories, const float* priors, bool calc_var_importance,
913 int nactive_vars, int max_num_of_trees_in_the_forest,
914 float forest_accuracy, int termcrit_type );
915 };
916
917
// Random trees model. Holds an array of CvForestTree pointers plus
// forest-level state (training data descriptor, counters, variable
// importance, RNG and active-variable mask).
918 class CvRTrees : public CvStatModel
919 {
920 public:
921 CV_WRAP CvRTrees();
922 virtual ~CvRTrees();
    // Training from CvMat inputs. tflag selects the sample layout; all
    // optional matrices default to 0 (null).
923 virtual bool train( const CvMat* trainData, int tflag,
924 const CvMat* responses, const CvMat* varIdx=0,
925 const CvMat* sampleIdx=0, const CvMat* varType=0,
926 const CvMat* missingDataMask=0,
927 CvRTParams params=CvRTParams() );
928
    // Training from a CvMLData container.
929 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    // Prediction for one sample, with an optional missing-value mask.
930 virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    // NOTE(review): presumably returns a probability-like score rather than a
    // label -- confirm which problem types are supported.
931 virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
932
    // cv::Mat counterparts of the methods above (exported through CV_WRAP).
933 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
934 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
935 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
936 const cv::Mat& missingDataMask=cv::Mat(),
937 CvRTParams params=CvRTParams() );
938 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
939 CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    // Variable importance as a cv::Mat; see also get_var_importance() below.
940 CV_WRAP virtual cv::Mat getVarImportance();
941
942 CV_WRAP virtual void clear();
943
    // Variable importance as a CvMat pointer.
    // NOTE(review): presumably meaningful only when training ran with
    // calc_var_importance set -- confirm.
944 virtual const CvMat* get_var_importance();
    // Proximity measure between two samples, each with an optional missing-value mask.
    // NOTE(review): the definition of the measure lives in the implementation -- confirm.
945 virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
946 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
947
    // Error on the train or test part of `data`; predictions can be returned through resp.
948 virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
949
950 virtual float get_train_error();
951
    // Model (de)serialisation through CvFileStorage.
952 virtual void read( CvFileStorage* fs, CvFileNode* node );
953 virtual void write( CvFileStorage* fs, const char* name ) const;
954
    // Accessors.
    // NOTE(review): presumably expose the active_var_mask and rng members
    // declared below; note that get_rng() returns CvRNG* while the member is
    // cv::RNG* -- confirm the conversion in the implementation.
955 CvMat* get_active_var_mask();
956 CvRNG* get_rng();
957
    // Tree count and access to the i-th tree.
958 int get_tree_count() const;
959 CvForestTree* get_tree(int i) const;
960
961 protected:
962 virtual cv::String getName() const;
963
    // Grows the forest under the given termination criteria.
964 virtual bool grow_forest( const CvTermCriteria term_crit );
965
966 // array of the trees of the forest
967 CvForestTree** trees;
968 CvDTreeTrainData* data;
    // CvMat headers and cv::Mat objects for train data and responses.
    // NOTE(review): presumably used by the cv::Mat overloads to bridge to the
    // CvMat-based code -- confirm.
969 CvMat train_data_hdr, responses_hdr;
970 cv::Mat train_data_mat, responses_mat;
971 int ntrees;
972 int nclasses;
973 double oob_error; // NOTE(review): presumably the out-of-bag error estimate -- confirm
974 CvMat* var_importance;
975 int nsamples;
976
977 cv::RNG* rng;
978 CvMat* active_var_mask;
979 };
980
981 /****************************************************************************************\
982 * Extremely randomized trees Classifier *
983 \****************************************************************************************/
// Training-data descriptor for extremely randomized trees. Derives from
// CvDTreeTrainData, re-declares its data-access virtuals (presumably
// overriding the base versions -- the base class is not visible here) and
// adds a pointer to a missing-value mask.
984 struct CvERTreeTrainData : public CvDTreeTrainData
985 {
    // Attaches the training set. All optional matrices default to 0 (null);
    // the three trailing flags default to false.
986 virtual void set_data( const CvMat* trainData, int tflag,
987 const CvMat* responses, const CvMat* varIdx=0,
988 const CvMat* sampleIdx=0, const CvMat* varType=0,
989 const CvMat* missingDataMask=0,
990 const CvDTreeParams& params=CvDTreeParams(),
991 bool _shared=false, bool _add_labels=false,
992 bool _update_data=false );
    // Per-node accessors. The *_buf arguments are caller-provided buffers;
    // results are returned through the output pointers or the return value.
    // NOTE(review): whether the returned pointers alias the buffers is
    // decided by the implementation -- confirm.
993 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
994 const float** ord_values, const int** missing, int* sample_buf = 0 );
995 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
996 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
997 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    // Copies values, missing flags and responses for the selected samples into the output arrays.
998 virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
999 float* responses, bool get_class_idx=false );
    // Returns a node for the given subsample.
1000 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    // Missing-value mask.
    // NOTE(review): presumably the missingDataMask passed to set_data() and not owned here -- confirm.
1001 const CvMat* missing_mask;
1002 };
1003
// Forest tree variant for extremely randomized trees. Declares only protected
// virtuals for node direction, per-variable split search and node-data
// splitting.
1004 class CvForestERTree : public CvForestTree
1005 {
1006 protected:
1007 virtual double calc_node_dir( CvDTreeNode* node );
     // Split search for variable vi at node n: one method per combination of
     // variable kind (ord / cat) and problem kind (class / reg).
     // init_quality, _split and ext_buf are optional and default to 0.
1008 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1009 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1010 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1011 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1012 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1013 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1014 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1015 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1016 virtual void split_node_data( CvDTreeNode* n );
1017 };
1018
// Extremely randomized trees model. Derives from CvRTrees and re-declares the
// three train() overloads plus getName() and grow_forest().
1019 class CvERTrees : public CvRTrees
1020 {
1021 public:
1022 CV_WRAP CvERTrees();
1023 virtual ~CvERTrees();
     // Training from CvMat inputs (same parameter list as CvRTrees::train).
1024 virtual bool train( const CvMat* trainData, int tflag,
1025 const CvMat* responses, const CvMat* varIdx=0,
1026 const CvMat* sampleIdx=0, const CvMat* varType=0,
1027 const CvMat* missingDataMask=0,
1028 CvRTParams params=CvRTParams());
     // Training from cv::Mat inputs.
1029 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1030 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1031 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1032 const cv::Mat& missingDataMask=cv::Mat(),
1033 CvRTParams params=CvRTParams());
     // Training from a CvMLData container.
1034 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1035 protected:
1036 virtual cv::String getName() const;
     // Grows the forest under the given termination criteria.
1037 virtual bool grow_forest( const CvTermCriteria term_crit );
1038 };
1039
1040
1041 /****************************************************************************************\
1042 * Boosted tree classifier *
1043 \****************************************************************************************/
1044
// Parameter set for the boosted tree classifier: CvDTreeParams extended with
// the four boosting fields declared below.
1045 struct CvBoostParams : public CvDTreeParams
1046 {
1047 CV_PROP_RW int boost_type; // NOTE(review): presumably one of CvBoost::DISCRETE/REAL/LOGIT/GENTLE -- confirm
1048 CV_PROP_RW int weak_count; // NOTE(review): presumably the number of weak predictors to build -- confirm
1049 CV_PROP_RW int split_criteria; // NOTE(review): presumably one of CvBoost::DEFAULT/GINI/MISCLASS/SQERR -- confirm
1050 CV_PROP_RW double weight_trim_rate; // NOTE(review): presumably the threshold used by CvBoost::trim_weights() -- confirm
1051
     // Default constructor; the default values are set in the implementation file.
1052 CvBoostParams();
     // Constructor taking the boosting settings plus a subset of the tree
     // settings. split_criteria is not a constructor argument.
1053 CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
1054 int max_depth, bool use_surrogates, const float* priors );
1055 };
1056
1057
1058 class CvBoost;
1059
// Tree used as a weak predictor by CvBoost. Extends CvDTree with
// train()/read() overloads that receive a CvBoost ensemble, adds scale() and
// stores a pointer to an ensemble.
1060 class CvBoostTree: public CvDTree
1061 {
1062 public:
1063 CvBoostTree();
1064 virtual ~CvBoostTree();
1065
     // Ensemble-aware training overload.
1066 virtual bool train( CvDTreeTrainData* trainData,
1067 const CvMat* subsample_idx, CvBoost* ensemble );
1068
     // NOTE(review): presumably multiplies the tree's node values by s -- confirm in the implementation.
1069 virtual void scale( double s );
     // Ensemble-aware read overload.
1070 virtual void read( CvFileStorage* fs, CvFileNode* node,
1071 CvBoost* ensemble, CvDTreeTrainData* _data );
1072 virtual void clear();
1073
     // The overloads between the BEGIN/END markers take no ensemble argument.
     // Per the markers they are declared only to avoid compiler warnings.
1074 /* dummy methods to avoid warnings: BEGIN */
1075 virtual bool train( const CvMat* trainData, int tflag,
1076 const CvMat* responses, const CvMat* varIdx=0,
1077 const CvMat* sampleIdx=0, const CvMat* varType=0,
1078 const CvMat* missingDataMask=0,
1079 CvDTreeParams params=CvDTreeParams() );
1080 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
1081
1082 virtual void read( CvFileStorage* fs, CvFileNode* node );
1083 virtual void read( CvFileStorage* fs, CvFileNode* node,
1084 CvDTreeTrainData* data );
1085 /* dummy methods to avoid warnings: END */
1086
1087 protected:
1088
1089 virtual void try_split_node( CvDTreeNode* n );
     // Surrogate split search for variable vi at node n (ordered / categorical variable).
1090 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1091 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
     // Primary split search for variable vi at node n: one method per
     // combination of variable kind (ord / cat) and problem kind (class / reg).
     // init_quality, _split and ext_buf are optional and default to 0.
1092 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1093 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1094 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1095 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1096 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1097 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1098 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1099 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1100 virtual void calc_node_value( CvDTreeNode* n );
1101 virtual double calc_node_dir( CvDTreeNode* n );
1102
     // Pointer to an ensemble.
     // NOTE(review): presumably the one passed to train()/read() and not owned by the tree -- confirm.
1103 CvBoost* ensemble;
1104 };
1105
1106
// Boosted tree classifier. Keeps a sequence of weak predictors (`weak`), the
// boosting parameters and several per-model work matrices (weights,
// responses, masks).
1107 class CvBoost : public CvStatModel
1108 {
1109 public:
1110 // Boosting type
1111 enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1112
1113 // Splitting criteria
     // Note: no enumerator has the value 2 (MISCLASS is 3, SQERR is 4).
1114 enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
1115
1116 CV_WRAP CvBoost();
1117 virtual ~CvBoost();
1118
     // Constructor taking the same arguments as train() (without `update`).
     // NOTE(review): presumably trains the model immediately -- confirm in the implementation.
1119 CvBoost( const CvMat* trainData, int tflag,
1120 const CvMat* responses, const CvMat* varIdx=0,
1121 const CvMat* sampleIdx=0, const CvMat* varType=0,
1122 const CvMat* missingDataMask=0,
1123 CvBoostParams params=CvBoostParams() );
1124
     // Training from CvMat inputs. All optional matrices default to 0 (null);
     // `update` defaults to false.
1125 virtual bool train( const CvMat* trainData, int tflag,
1126 const CvMat* responses, const CvMat* varIdx=0,
1127 const CvMat* sampleIdx=0, const CvMat* varType=0,
1128 const CvMat* missingDataMask=0,
1129 CvBoostParams params=CvBoostParams(),
1130 bool update=false );
1131
     // Training from a CvMLData container.
1132 virtual bool train( CvMLData* data,
1133 CvBoostParams params=CvBoostParams(),
1134 bool update=false );
1135
     // Prediction for one sample. Optional arguments: missing-value mask,
     // output matrix for weak responses, slice of the weak-predictor sequence
     // (defaults to the whole sequence), and two mode flags.
     // NOTE(review): the exact effect of raw_mode and return_sum is defined
     // in the implementation -- confirm.
1136 virtual float predict( const CvMat* sample, const CvMat* missing=0,
1137 CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1138 bool raw_mode=false, bool return_sum=false ) const;
1139
     // cv::Mat counterparts of the constructor, train() and predict() above
     // (exported through CV_WRAP). This predict() has no weak_responses argument.
1140 CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
1141 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1142 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1143 const cv::Mat& missingDataMask=cv::Mat(),
1144 CvBoostParams params=CvBoostParams() );
1145
1146 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1147 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1148 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1149 const cv::Mat& missingDataMask=cv::Mat(),
1150 CvBoostParams params=CvBoostParams(),
1151 bool update=false );
1152
1153 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1154 const cv::Range& slice=cv::Range::all(), bool rawMode=false,
1155 bool returnSum=false ) const;
1156
     // Error on the train or test part of _data; predictions can be returned through resp.
1157 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1158
     // NOTE(review): presumably removes the weak predictors selected by `slice` -- confirm.
1159 CV_WRAP virtual void prune( CvSlice slice );
1160
1161 CV_WRAP virtual void clear();
1162
     // Model (de)serialisation through CvFileStorage.
1163 virtual void write( CvFileStorage* storage, const char* name ) const;
1164 virtual void read( CvFileStorage* storage, CvFileNode* node );
     // NOTE(review): presumably returns active_vars_abs when absolute_idx is
     // true and active_vars otherwise -- confirm.
1165 virtual const CvMat* get_active_vars(bool absolute_idx=true);
1166
     // Accessors.
     // NOTE(review): presumably return the correspondingly named members
     // declared below -- confirm.
1167 CvSeq* get_weak_predictors();
1168
1169 CvMat* get_weights();
1170 CvMat* get_subtree_weights();
1171 CvMat* get_weak_response();
1172 const CvBoostParams& get_params() const;
1173 const CvDTreeTrainData* get_data() const;
1174
1175 protected:
1176
     // Internal training steps and parameter (de)serialisation helpers.
1177 virtual bool set_params( const CvBoostParams& params );
1178 virtual void update_weights( CvBoostTree* tree );
1179 virtual void trim_weights();
1180 virtual void write_params( CvFileStorage* fs ) const;
1181 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1182
     // Takes a reference to an array of exactly two doubles.
     // NOTE(review): the meaning of the two entries is defined in the implementation -- confirm.
1183 virtual void initialize_weights(double (&p)[2]);
1184
1185 CvDTreeTrainData* data;
     // CvMat headers and cv::Mat objects for train data and responses.
     // NOTE(review): presumably used by the cv::Mat overloads to bridge to the
     // CvMat-based code -- confirm.
1186 CvMat train_data_hdr, responses_hdr;
1187 cv::Mat train_data_mat, responses_mat;
1188 CvBoostParams params;
1189 CvSeq* weak; // NOTE(review): presumably a sequence of CvBoostTree pointers -- confirm
1190
1191 CvMat* active_vars;
1192 CvMat* active_vars_abs;
1193 bool have_active_cat_vars;
1194
1195 CvMat* orig_response;
1196 CvMat* sum_response;
1197 CvMat* weak_eval;
1198 CvMat* subsample_mask;
1199 CvMat* weights;
1200 CvMat* subtree_weights;
1201 bool have_subsample;
1202 };
1203
1204
1205 /****************************************************************************************\
1206 * Gradient Boosted Trees *
1207 \****************************************************************************************/
1208
1209 // DataType: STRUCT CvGBTreesParams
1210 // Parameters of GBT (Gradient Boosted trees model), including single
1211 // tree settings and ensemble parameters.
1212 //
1213 // weak_count - count of trees in the ensemble
1214 // loss_function_type - loss function used for ensemble training
1215 // subsample_portion - portion of whole training set used for
1216 // every single tree training.
1217 // subsample_portion value is in (0.0, 1.0].
1218 // subsample_portion == 1.0 when whole dataset is
1219 // used on each step. Count of sample used on each
1220 // step is computed as
1221 // int(total_samples_count * subsample_portion).
1222 // shrinkage - regularization parameter.
1223 // Each tree prediction is multiplied by the shrinkage value.
1224
1225
// Parameter set for gradient boosted trees: CvDTreeParams extended with the
// four ensemble-level fields declared below.
1226 struct CvGBTreesParams : public CvDTreeParams
1227 {
1228 CV_PROP_RW int weak_count; // count of trees in the ensemble
1229 CV_PROP_RW int loss_function_type; // loss function used for ensemble training
1230 CV_PROP_RW float subsample_portion; // portion of the training set used for each tree, in (0.0, 1.0]
1231 CV_PROP_RW float shrinkage; // regularization parameter applied to each tree prediction
1232
     // Default constructor; the default values are set in the implementation file.
1233 CvGBTreesParams();
     // Constructor taking the four ensemble settings plus two tree settings
     // (max_depth, use_surrogates).
1234 CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
1235 float subsample_portion, int max_depth, bool use_surrogates );
1236 };
1237
1238 // DataType: CLASS CvGBTrees
1239 // Gradient Boosting Trees (GBT) algorithm implementation.
1240 //
1241 // data - training dataset
1242 // params - parameters of the CvGBTrees
1243 // weak - array[0..(class_count-1)] of CvSeq
1244 // for storing tree ensembles
1245 // orig_response - original responses of the training set samples
1246 // sum_response - predictions of the current model on the training dataset.
1247 // this matrix is updated on every iteration.
1248 // sum_response_tmp - predictions of the model on the training set on the next
1249 // step. On every iteration values of sum_response_tmp are
1250 // computed via sum_response values. When the current
1251 // step is complete sum_response values become equal to
1252 // sum_response_tmp.
1253 // sample_idx - indices of samples used for training the ensemble.
1254 // CvGBTrees training procedure takes a set of samples
1255 // (train_data) and a set of responses (responses).
1256 // Only pairs (train_data[i], responses[i]), where i is
1257 // in sample_idx are used for training the ensemble.
1258 // subsample_train - indices of samples used for training a single decision
1259 // tree on the current step. These indices are counted
1260 // relative to sample_idx, so that pairs
1261 // (train_data[sample_idx[i]], responses[sample_idx[i]])
1262 // are used for training a decision tree.
1263 // The training set is randomly split
1264 // into two parts (subsample_train and subsample_test)
1265 // on every iteration according to the subsample_portion parameter.
1266 // subsample_test - relative indices of samples from the training set,
1267 // which are not used for training a tree on the current
1268 // step.
1269 // missing - mask of the missing values in the training set. This
1270 // matrix has the same size as train_data. 1 - missing
1271 // value, 0 - not a missing value.
1272 // class_labels - output class labels map.
1273 // rng - random number generator. Used for splitting the
1274 // training set.
1275 // class_count - count of output classes.
1276 // class_count == 1 in the case of regression,
1277 // and > 1 in the case of classification.
1278 // delta - Huber loss function parameter.
1279 // base_value - start point of the gradient descent procedure.
1280 // model prediction is
1281 // f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
1282 // f_0 is the base value.
1283
1284
1285
// Gradient boosted trees model. The roles of the data members are described
// in the comment block that precedes this class.
1286 class CvGBTrees : public CvStatModel
1287 {
1288 public:
1289
1290 /*
1291 // DataType: ENUM
1292 // Loss functions implemented in CvGBTrees.
1293 //
1294 // SQUARED_LOSS
1295 // problem: regression
1296 // loss = (x - x')^2
1297 //
1298 // ABSOLUTE_LOSS
1299 // problem: regression
1300 // loss = abs(x - x')
1301 //
1302 // HUBER_LOSS
1303 // problem: regression
1304 // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
1305 // 1/2*(x - x')^2, if abs(x - x') <= delta,
1306 // where delta is the alpha-quantile of pseudo responses from
1307 // the training set.
1308 //
1309 // DEVIANCE_LOSS
1310 // problem: classification
1311 //
1312 */
     // Resulting values: SQUARED_LOSS=0, ABSOLUTE_LOSS=1, HUBER_LOSS=3,
     // DEVIANCE_LOSS=4 (no enumerator has the value 2).
1313 enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
1314
1315
1316 /*
1317 // Default constructor. Creates a model only (without training).
1318 // Should be followed by one form of the train(...) function.
1319 //
1320 // API
1321 // CvGBTrees();
1322
1323 // INPUT
1324 // OUTPUT
1325 // RESULT
1326 */
1327 CV_WRAP CvGBTrees();
1328
1329
1330 /*
1331 // Full form constructor. Creates a gradient boosting model and does the
1332 // train.
1333 //
1334 // API
1335 // CvGBTrees( const CvMat* trainData, int tflag,
1336 const CvMat* responses, const CvMat* varIdx=0,
1337 const CvMat* sampleIdx=0, const CvMat* varType=0,
1338 const CvMat* missingDataMask=0,
1339 CvGBTreesParams params=CvGBTreesParams() );
1340
1341 // INPUT
1342 // trainData - a set of input feature vectors.
1343 // size of matrix is
1344 // <count of samples> x <variables count>
1345 // or <variables count> x <count of samples>
1346 // depending on the tflag parameter.
1347 // matrix values are float.
1348 // tflag - a flag showing how samples are stored in the
1349 // trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1350 // or column by column (tflag=CV_COL_SAMPLE).
1351 // responses - a vector of responses corresponding to the samples
1352 // in trainData.
1353 // varIdx - indices of used variables. zero value means that all
1354 // variables are active.
1355 // sampleIdx - indices of used samples. zero value means that all
1356 // samples from trainData are in the training set.
1357 // varType - vector of <variables count> length. gives every
1358 // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1359 // varType = 0 means all variables are numerical.
1360 // missingDataMask - a mask of missing values in trainData.
1361 // missingDataMask = 0 means that there are no missing
1362 // values.
1363 // params - parameters of GTB algorithm.
1364 // OUTPUT
1365 // RESULT
1366 */
1367 CvGBTrees( const CvMat* trainData, int tflag,
1368 const CvMat* responses, const CvMat* varIdx=0,
1369 const CvMat* sampleIdx=0, const CvMat* varType=0,
1370 const CvMat* missingDataMask=0,
1371 CvGBTreesParams params=CvGBTreesParams() );
1372
1373
1374 /*
1375 // Destructor.
1376 */
1377 virtual ~CvGBTrees();
1378
1379
1380 /*
1381 // Gradient tree boosting model training
1382 //
1383 // API
1384 // virtual bool train( const CvMat* trainData, int tflag,
1385 const CvMat* responses, const CvMat* varIdx=0,
1386 const CvMat* sampleIdx=0, const CvMat* varType=0,
1387 const CvMat* missingDataMask=0,
1388 CvGBTreesParams params=CvGBTreesParams(),
1389 bool update=false );
1390
1391 // INPUT
1392 // trainData - a set of input feature vectors.
1393 // size of matrix is
1394 // <count of samples> x <variables count>
1395 // or <variables count> x <count of samples>
1396 // depending on the tflag parameter.
1397 // matrix values are float.
1398 // tflag - a flag showing how samples are stored in the
1399 // trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1400 // or column by column (tflag=CV_COL_SAMPLE).
1401 // responses - a vector of responses corresponding to the samples
1402 // in trainData.
1403 // varIdx - indices of used variables. zero value means that all
1404 // variables are active.
1405 // sampleIdx - indices of used samples. zero value means that all
1406 // samples from trainData are in the training set.
1407 // varType - vector of <variables count> length. gives every
1408 // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1409 // varType = 0 means all variables are numerical.
1410 // missingDataMask - a mask of missing values in trainData.
1411 // missingDataMask = 0 means that there are no missing
1412 // values.
1413 // params - parameters of GTB algorithm.
1414 // update - is not supported now. (!)
1415 // OUTPUT
1416 // RESULT
1417 // Error state.
1418 */
1419 virtual bool train( const CvMat* trainData, int tflag,
1420 const CvMat* responses, const CvMat* varIdx=0,
1421 const CvMat* sampleIdx=0, const CvMat* varType=0,
1422 const CvMat* missingDataMask=0,
1423 CvGBTreesParams params=CvGBTreesParams(),
1424 bool update=false );
1425
1426
1427 /*
1428 // Gradient tree boosting model training
1429 //
1430 // API
1431 // virtual bool train( CvMLData* data,
1432 CvGBTreesParams params=CvGBTreesParams(),
1433 bool update=false );
1434
1435 // INPUT
1436 // data - training set.
1437 // params - parameters of GTB algorithm.
1438 // update - is not supported now. (!)
1439 // OUTPUT
1440 // RESULT
1441 // Error state.
1442 */
1443 virtual bool train( CvMLData* data,
1444 CvGBTreesParams params=CvGBTreesParams(),
1445 bool update=false );
1446
1447
1448 /*
1449 // Response value prediction
1450 //
1451 // API
1452 // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1453 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1454 int k=-1 ) const;
1455
1456 // INPUT
1457 // sample - input sample of the same type as in the training set.
1458 // missing - missing values mask. missing=0 if there are no
1459 // missing values in sample vector.
1460 // weakResponses - predictions of all of the trees.
1461 // not implemented (!)
1462 // slice - part of the ensemble used for prediction.
1463 // slice = CV_WHOLE_SEQ when all trees are used.
1464 // k - number of ensemble used.
1465 // k is in {-1,0,1,..,<count of output classes-1>}.
1466 // in the case of classification problem
1467 // <count of output classes-1> ensembles are built.
1468 // If k = -1 ordinary prediction is the result,
1469 // otherwise function gives the prediction of the
1470 // k-th ensemble only.
1471 // OUTPUT
1472 // RESULT
1473 // Predicted value.
1474 */
1475 virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1476 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1477 int k=-1 ) const;
1478
1479 /*
1480 // Response value prediction.
1481 // Parallel version (in the case of TBB existence)
1482 //
1483 // API
1484 // virtual float predict( const CvMat* sample, const CvMat* missing=0,
1485 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1486 int k=-1 ) const;
1487
1488 // INPUT
1489 // sample - input sample of the same type as in the training set.
1490 // missing - missing values mask. missing=0 if there are no
1491 // missing values in sample vector.
1492 // weakResponses - predictions of all of the trees.
1493 // not implemented (!)
1494 // slice - part of the ensemble used for prediction.
1495 // slice = CV_WHOLE_SEQ when all trees are used.
1496 // k - number of ensemble used.
1497 // k is in {-1,0,1,..,<count of output classes-1>}.
1498 // in the case of classification problem
1499 // <count of output classes-1> ensembles are built.
1500 // If k = -1 ordinary prediction is the result,
1501 // otherwise function gives the prediction of the
1502 // k-th ensemble only.
1503 // OUTPUT
1504 // RESULT
1505 // Predicted value.
1506 */
1507 virtual float predict( const CvMat* sample, const CvMat* missing=0,
1508 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1509 int k=-1 ) const;
1510
1511 /*
1512 // Deletes all the data.
1513 //
1514 // API
1515 // virtual void clear();
1516
1517 // INPUT
1518 // OUTPUT
1519 // delete data, weak, orig_response, sum_response,
1520 // weak_eval, subsample_train, subsample_test,
1521 // sample_idx, missing, class_labels
1522 // delta = 0.0
1523 // RESULT
1524 */
1525 CV_WRAP virtual void clear();
1526
1527 /*
1528 // Compute error on the train/test set.
1529 //
1530 // API
1531 // virtual float calc_error( CvMLData* _data, int type,
1532 // std::vector<float> *resp = 0 );
1533 //
1534 // INPUT
1535 // data - dataset
1536 // type - defines which error is to compute: train (CV_TRAIN_ERROR) or
1537 // test (CV_TEST_ERROR).
1538 // OUTPUT
1539 // resp - vector of predictions
1540 // RESULT
1541 // Error value.
1542 */
1543 virtual float calc_error( CvMLData* _data, int type,
1544 std::vector<float> *resp = 0 );
1545
1546 /*
1547 //
1548 // Write parameters of the gtb model and data. Write learned model.
1549 //
1550 // API
1551 // virtual void write( CvFileStorage* fs, const char* name ) const;
1552 //
1553 // INPUT
1554 // fs - file storage to write the model to.
1555 // name - model name.
1556 // OUTPUT
1557 // RESULT
1558 */
1559 virtual void write( CvFileStorage* fs, const char* name ) const;
1560
1561
1562 /*
1563 //
1564 // Read parameters of the gtb model and data. Read learned model.
1565 //
1566 // API
1567 // virtual void read( CvFileStorage* fs, CvFileNode* node );
1568 //
1569 // INPUT
1570 // fs - file storage to read parameters from.
1571 // node - file node.
1572 // OUTPUT
1573 // RESULT
1574 */
1575 virtual void read( CvFileStorage* fs, CvFileNode* node );
1576
1577
1578 // new-style C++ interface
     // cv::Mat counterparts of the constructor, train() and predict() above
     // (exported through CV_WRAP). This predict() has no weakResponses argument.
1579 CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
1580 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1581 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1582 const cv::Mat& missingDataMask=cv::Mat(),
1583 CvGBTreesParams params=CvGBTreesParams() );
1584
1585 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1586 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1587 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1588 const cv::Mat& missingDataMask=cv::Mat(),
1589 CvGBTreesParams params=CvGBTreesParams(),
1590 bool update=false );
1591
1592 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1593 const cv::Range& slice = cv::Range::all(),
1594 int k=-1 ) const;
1595
1596 protected:
1597
1598 /*
1599 // Compute the gradient vector components.
1600 //
1601 // API
1602 // virtual void find_gradient( const int k = 0);
1603
1604 // INPUT
1605 // k - used for classification problem, determining current
1606 // tree ensemble.
1607 // OUTPUT
1608 // changes components of data->responses
1609 // which correspond to samples used for training
1610 // on the current step.
1611 // RESULT
1612 */
1613 virtual void find_gradient( const int k = 0);
1614
1615
1616 /*
1617 //
1618 // Change values in tree leaves according to the used loss function.
1619 //
1620 // API
1621 // virtual void change_values(CvDTree* tree, const int k = 0);
1622 //
1623 // INPUT
1624 // tree - decision tree to change.
1625 // k - used for classification problem, determining current
1626 // tree ensemble.
1627 // OUTPUT
1628 // changes 'value' fields of the trees' leaves.
1629 // changes sum_response_tmp.
1630 // RESULT
1631 */
1632 virtual void change_values(CvDTree* tree, const int k = 0);
1633
1634
1635 /*
1636 //
1637 // Find optimal constant prediction value according to the used loss
1638 // function.
1639 // The goal is to find a constant which gives the minimal summary loss
1640 // on the _Idx samples.
1641 //
1642 // API
1643 // virtual float find_optimal_value( const CvMat* _Idx );
1644 //
1645 // INPUT
1646 // _Idx - indices of the samples from the training set.
1647 // OUTPUT
1648 // RESULT
1649 // optimal constant value.
1650 */
1651 virtual float find_optimal_value( const CvMat* _Idx );
1652
1653
1654 /*
1655 //
1656 // Randomly split the whole training set in two parts according
1657 // to params.subsample_portion.
1658 //
1659 // API
1660 // virtual void do_subsample();
1661 //
1662 // INPUT
1663 // OUTPUT
1664 // subsample_train - indices of samples used for training
1665 // subsample_test - indices of samples used for test
1666 // RESULT
1667 */
1668 virtual void do_subsample();
1669
1670
1671 /*
1672 //
1673 // Internal recursive function collecting the leaves of a subtree into an array.
1674 //
1675 // API
1676 // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1677 //
1678 // INPUT
1679 // node - current leaf.
1680 // OUTPUT
1681 // count - count of leaves in the subtree.
1682 // leaves - array of pointers to leaves.
1683 // RESULT
1684 */
1685 void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1686
1687
1688 /*
1689 //
1690 // Get leaves of the tree.
1691 //
1692 // API
1693 // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1694 //
1695 // INPUT
1696 // dtree - decision tree.
1697 // OUTPUT
1698 // len - count of the leaves.
1699 // RESULT
1700 // CvDTreeNode** - array of pointers to leaves.
1701 */
1702 CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1703
1704
1705 /*
1706 //
1707 // Is it a regression or a classification.
1708 //
1709 // API
1710 // virtual bool problem_type() const;
1711 //
1712 // INPUT
1713 // OUTPUT
1714 // RESULT
1715 // false if it is a classification problem,
1716 // true - if regression.
1717 */
1718 virtual bool problem_type() const;
1719
1720
1721 /*
1722 //
1723 // Write parameters of the gtb model.
1724 //
1725 // API
1726 // virtual void write_params( CvFileStorage* fs ) const;
1727 //
1728 // INPUT
1729 // fs - file storage to write parameters to.
1730 // OUTPUT
1731 // RESULT
1732 */
1733 virtual void write_params( CvFileStorage* fs ) const;
1734
1735
1736 /*
1737 //
1738 // Read parameters of the gtb model and data.
1739 //
1740 // API
1741 // virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
1742 //
1743 // INPUT
1744 // fs - file storage to read parameters from.
1745 // OUTPUT
1746 // params - parameters of the gtb model.
1747 // data - contains information about the structure
1748 // of the data set (count of variables,
1749 // their types, etc.).
1750 // class_labels - output class labels map.
1751 // RESULT
1752 */
     // NOTE(review): fnode is not described in the comment above; presumably
     // it is the file node the parameters are read from -- confirm.
1753 virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
     // NOTE(review): presumably returns the number of elements of a
     // vector-shaped matrix -- confirm in the implementation.
1754 int get_len(const CvMat* mat) const;
1755
1756
1757 CvDTreeTrainData* data;
1758 CvGBTreesParams params;
1759
1760 CvSeq** weak;
1761 CvMat* orig_response;
1762 CvMat* sum_response;
1763 CvMat* sum_response_tmp;
1764 CvMat* sample_idx;
1765 CvMat* subsample_train;
1766 CvMat* subsample_test;
1767 CvMat* missing;
1768 CvMat* class_labels;
1769
1770 cv::RNG* rng;
1771
1772 int class_count;
1773 float delta;
1774 float base_value;
1775
1776 };
1777
1778
1779
1780 /****************************************************************************************\
1781 * Artificial Neural Networks (ANN) *
1782 \****************************************************************************************/
1783
1784 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1785
// Training parameters for the multi-layer perceptron: termination criteria,
// the training method, and the parameters specific to each method.
1786 struct CvANN_MLP_TrainParams
1787 {
     // Default constructor; the default values are set in the implementation file.
1788 CvANN_MLP_TrainParams();
     // NOTE(review): param1/param2 presumably initialise the parameters of
     // the selected train_method -- confirm which fields in the implementation.
1789 CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1790 double param1, double param2=0 );
1791 ~CvANN_MLP_TrainParams();
1792
     // Training methods (values for train_method).
1793 enum { BACKPROP=0, RPROP=1 };
1794
1795 CV_PROP_RW CvTermCriteria term_crit;
1796 CV_PROP_RW int train_method;
1797
1798 // backpropagation parameters
1799 CV_PROP_RW double bp_dw_scale, bp_moment_scale;
1800
1801 // rprop parameters
1802 CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1803 };
1804
1805
1806 class CvANN_MLP : public CvStatModel
1807 {
1808 public:
1809 CV_WRAP CvANN_MLP();
1810 CvANN_MLP( const CvMat* layerSizes,
1811 int activateFunc=CvANN_MLP::SIGMOID_SYM,
1812 double fparam1=0, double fparam2=0 );
1813
1814 virtual ~CvANN_MLP();
1815
1816 virtual void create( const CvMat* layerSizes,
1817 int activateFunc=CvANN_MLP::SIGMOID_SYM,
1818 double fparam1=0, double fparam2=0 );
1819
1820 virtual int train( const CvMat* inputs, const CvMat* outputs,
1821 const CvMat* sampleWeights, const CvMat* sampleIdx=0,
1822 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1823 int flags=0 );
1824 virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
1825
1826 CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
1827 int activateFunc=CvANN_MLP::SIGMOID_SYM,
1828 double fparam1=0, double fparam2=0 );
1829
1830 CV_WRAP virtual void create( const cv::Mat& layerSizes,
1831 int activateFunc=CvANN_MLP::SIGMOID_SYM,
1832 double fparam1=0, double fparam2=0 );
1833
1834 CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
1835 const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
1836 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1837 int flags=0 );
1838
1839 CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
1840
1841 CV_WRAP virtual void clear();
1842
1843 // possible activation functions
1844 enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1845
1846 // available training flags
1847 enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1848
1849 virtual void read( CvFileStorage* fs, CvFileNode* node );
1850 virtual void write( CvFileStorage* storage, const char* name ) const;
1851
get_layer_count()1852 int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
get_layer_sizes()1853 const CvMat* get_layer_sizes() { return layer_sizes; }
get_weights(int layer)1854 double* get_weights(int layer)
1855 {
1856 return layer_sizes && weights &&
1857 (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
1858 }
1859
1860 virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1861
1862 protected:
1863
1864 virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1865 const CvMat* _sample_weights, const CvMat* sampleIdx,
1866 CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1867
1868 // sequential random backpropagation
1869 virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1870
1871 // RPROP algorithm
1872 virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1873
1874 virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1875 virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1876 double _f_param1=0, double _f_param2=0 );
1877 virtual void init_weights();
1878 virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1879 virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1880 virtual void calc_input_scale( const CvVectors* vecs, int flags );
1881 virtual void calc_output_scale( const CvVectors* vecs, int flags );
1882
1883 virtual void write_params( CvFileStorage* fs ) const;
1884 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1885
1886 CvMat* layer_sizes;
1887 CvMat* wbuf;
1888 CvMat* sample_weights;
1889 double** weights;
1890 double f_param1, f_param2;
1891 double min_val, max_val, min_val1, max_val1;
1892 int activ_func;
1893 int max_count, max_buf_sz;
1894 CvANN_MLP_TrainParams params;
1895 cv::RNG* rng;
1896 };
1897
1898 /****************************************************************************************\
*                            Auxiliary functions declarations                            *
1900 \****************************************************************************************/
1901
1902 /* Generates <sample> from multivariate normal distribution, where <mean> - is an
1903 average row vector, <cov> - symmetric covariation matrix */
1904 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
1905 CvRNG* rng CV_DEFAULT(0) );
1906
1907 /* Generates sample from gaussian mixture distribution */
1908 CVAPI(void) cvRandGaussMixture( CvMat* means[],
1909 CvMat* covs[],
1910 float weights[],
1911 int clsnum,
1912 CvMat* sample,
1913 CvMat* sampClasses CV_DEFAULT(0) );
1914
1915 #define CV_TS_CONCENTRIC_SPHERES 0
1916
1917 /* creates test set */
1918 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
1919 int num_samples,
1920 int num_features,
1921 CvMat** responses,
1922 int num_classes, ... );
1923
1924 /****************************************************************************************\
1925 * Data *
1926 \****************************************************************************************/
1927
/* NOTE(review): presumably the values stored in
   CvTrainTestSplit::train_sample_part_mode -- confirm against the constructors. */
#define CV_COUNT     0
#define CV_PORTION   1

/* Description of a train/test split; a pointer to it is accepted by
   CvMLData::set_train_test_split(). The size of the train part can be given either
   as an integer count or as a floating-point portion (one constructor for each). */
struct CvTrainTestSplit
{
    CvTrainTestSplit();
    CvTrainTestSplit( int train_sample_count, bool mix = true );
    CvTrainTestSplit( float train_sample_portion, bool mix = true );

    // size of the train part; <count> and <portion> share the same storage
    union
    {
        int count;
        float portion;
    } train_sample_part;

    // NOTE(review): presumably tells which member of train_sample_part is meaningful
    // (CV_COUNT or CV_PORTION) -- confirm against the constructor definitions
    int train_sample_part_mode;

    // NOTE(review): presumably requests mixing of the train and test indices
    // (cf. CvMLData::mix_train_and_test_idx) -- confirm
    bool mix;
};
1946
1947 class CvMLData
1948 {
1949 public:
1950 CvMLData();
1951 virtual ~CvMLData();
1952
1953 // returns:
1954 // 0 - OK
1955 // -1 - file can not be opened or is not correct
1956 int read_csv( const char* filename );
1957
1958 const CvMat* get_values() const;
1959 const CvMat* get_responses();
1960 const CvMat* get_missing() const;
1961
1962 void set_header_lines_number( int n );
1963 int get_header_lines_number() const;
1964
1965 void set_response_idx( int idx ); // old response become predictors, new response_idx = idx
1966 // if idx < 0 there will be no response
1967 int get_response_idx() const;
1968
1969 void set_train_test_split( const CvTrainTestSplit * spl );
1970 const CvMat* get_train_sample_idx() const;
1971 const CvMat* get_test_sample_idx() const;
1972 void mix_train_and_test_idx();
1973
1974 const CvMat* get_var_idx();
1975 void chahge_var_idx( int vi, bool state ); // misspelled (saved for back compitability),
1976 // use change_var_idx
1977 void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
1978
1979 const CvMat* get_var_types();
1980 int get_var_type( int var_idx ) const;
1981 // following 2 methods enable to change vars type
1982 // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
1983 // with numerical labels; in the other cases var types are correctly determined automatically
1984 void set_var_types( const char* str ); // str examples:
1985 // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
1986 // "cat", "ord" (all vars are categorical/ordered)
1987 void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
1988
1989 void set_delimiter( char ch );
1990 char get_delimiter() const;
1991
1992 void set_miss_ch( char ch );
1993 char get_miss_ch() const;
1994
1995 const std::map<cv::String, int>& get_class_labels_map() const;
1996
1997 protected:
1998 virtual void clear();
1999
2000 void str_to_flt_elem( const char* token, float& flt_elem, int& type);
2001 void free_train_test_idx();
2002
2003 char delimiter;
2004 char miss_ch;
2005 //char flt_separator;
2006
2007 CvMat* values;
2008 CvMat* missing;
2009 CvMat* var_types;
2010 CvMat* var_idx_mask;
2011
2012 CvMat* response_out; // header
2013 CvMat* var_idx_out; // mat
2014 CvMat* var_types_out; // mat
2015
2016 int header_lines_number;
2017
2018 int response_idx;
2019
2020 int train_sample_count;
2021 bool mix;
2022
2023 int total_class_count;
2024 std::map<cv::String, int> class_map;
2025
2026 CvMat* train_sample_idx;
2027 CvMat* test_sample_idx;
2028 int* sample_idx; // data of train_sample_idx and test_sample_idx
2029
2030 cv::RNG* rng;
2031 };
2032
2033
2034 namespace cv
2035 {
2036
2037 typedef CvStatModel StatModel;
2038 typedef CvParamGrid ParamGrid;
2039 typedef CvNormalBayesClassifier NormalBayesClassifier;
2040 typedef CvKNearest KNearest;
2041 typedef CvSVMParams SVMParams;
2042 typedef CvSVMKernel SVMKernel;
2043 typedef CvSVMSolver SVMSolver;
2044 typedef CvSVM SVM;
2045 typedef CvDTreeParams DTreeParams;
2046 typedef CvMLData TrainData;
2047 typedef CvDTree DecisionTree;
2048 typedef CvForestTree ForestTree;
2049 typedef CvRTParams RandomTreeParams;
2050 typedef CvRTrees RandomTrees;
2051 typedef CvERTreeTrainData ERTreeTRainData;
2052 typedef CvForestERTree ERTree;
2053 typedef CvERTrees ERTrees;
2054 typedef CvBoostParams BoostParams;
2055 typedef CvBoostTree BoostTree;
2056 typedef CvBoost Boost;
2057 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
2058 typedef CvANN_MLP NeuralNet_MLP;
2059 typedef CvGBTreesParams GradientBoostingTreeParams;
2060 typedef CvGBTrees GradientBoostingTrees;
2061
2062 template<> void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const;
2063 }
2064
2065 #endif // __cplusplus
2066 #endif // __OPENCV_ML_HPP__
2067
2068 /* End of file. */
2069