1 #ifndef _OPENCV_BOOST_H_
2 #define _OPENCV_BOOST_H_
3 
4 #include "traincascade_features.h"
5 #include "old_ml.hpp"
6 
7 struct CvCascadeBoostParams : CvBoostParams
8 {
9     float minHitRate;
10     float maxFalseAlarm;
11 
12     CvCascadeBoostParams();
13     CvCascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm,
14                           double _weightTrimRate, int _maxDepth, int _maxWeakCount );
~CvCascadeBoostParamsCvCascadeBoostParams15     virtual ~CvCascadeBoostParams() {}
16     void write( cv::FileStorage &fs ) const;
17     bool read( const cv::FileNode &node );
18     virtual void printDefaults() const;
19     virtual void printAttrs() const;
20     virtual bool scanAttr( const std::string prmName, const std::string val);
21 };
22 
23 struct CvCascadeBoostTrainData : CvDTreeTrainData
24 {
25     CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
26                              const CvDTreeParams& _params );
27     CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
28                              int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
29                              const CvDTreeParams& _params = CvDTreeParams() );
30     virtual void setData( const CvFeatureEvaluator* _featureEvaluator,
31                           int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
32                           const CvDTreeParams& _params=CvDTreeParams() );
33     void precalculate();
34 
35     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
36 
37     virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
38     virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
39     virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );
40 
41     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
42                                   const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf );
43     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf );
44     virtual float getVarValue( int vi, int si );
45     virtual void free_train_data();
46 
47     const CvFeatureEvaluator* featureEvaluator;
48     cv::Mat valCache; // precalculated feature values (CV_32FC1)
49     CvMat _resp; // for casting
50     int numPrecalcVal, numPrecalcIdx;
51 };
52 
53 class CvCascadeBoostTree : public CvBoostTree
54 {
55 public:
56     virtual CvDTreeNode* predict( int sampleIdx ) const;
57     void write( cv::FileStorage &fs, const cv::Mat& featureMap );
58     void read( const cv::FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data );
59     void markFeaturesInMap( cv::Mat& featureMap );
60 protected:
61     virtual void split_node_data( CvDTreeNode* n );
62 };
63 
64 class CvCascadeBoost : public CvBoost
65 {
66 public:
67     virtual bool train( const CvFeatureEvaluator* _featureEvaluator,
68                         int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
69                         const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
70     virtual float predict( int sampleIdx, bool returnSum = false ) const;
71 
getThreshold()72     float getThreshold() const { return threshold; }
73     void write( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
74     bool read( const cv::FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
75                const CvCascadeBoostParams& _params );
76     void markUsedFeaturesInMap( cv::Mat& featureMap );
77 protected:
78     virtual bool set_params( const CvBoostParams& _params );
79     virtual void update_weights( CvBoostTree* tree );
80     virtual bool isErrDesired();
81 
82     float threshold;
83     float minHitRate, maxFalseAlarm;
84 };
85 
86 #endif
87