1 #ifdef __GNUC__
2 #  pragma GCC diagnostic ignored "-Wmissing-declarations"
3 #  if defined __clang__ || defined __APPLE__
4 #    pragma GCC diagnostic ignored "-Wmissing-prototypes"
5 #    pragma GCC diagnostic ignored "-Wextra"
6 #  endif
7 #endif
8 
9 #ifndef __OPENCV_TEST_PRECOMP_HPP__
10 #define __OPENCV_TEST_PRECOMP_HPP__
11 
12 #include <iostream>
13 #include <map>
14 #include "opencv2/ts.hpp"
15 #include "opencv2/ml.hpp"
16 #include "opencv2/core/core_c.h"
17 
// Model-name string literals for the ML tests. They are kept as macros (not
// constexpr variables) so they remain string literals and can still be
// concatenated with adjacent literals at use sites.
#define CV_NBAYES "nbayes"
#define CV_KNEAREST "knearest"
#define CV_SVM "svm"
#define CV_EM "em"
#define CV_ANN "ann"
#define CV_DTREE "dtree"
#define CV_BOOST "boost"
#define CV_RTREES "rtrees"
#define CV_ERTREES "ertrees"

// Selector for which error figure is meant: training-set (0) or test-set (1).
// NOTE(review): presumably passed as an index by the test .cpp files — confirm.
enum
{
    CV_TRAIN_ERROR = 0,
    CV_TEST_ERROR = 1
};
29 
30 using cv::Ptr;
31 using cv::ml::StatModel;
32 using cv::ml::TrainData;
33 using cv::ml::NormalBayesClassifier;
34 using cv::ml::SVM;
35 using cv::ml::KNearest;
36 using cv::ml::ParamGrid;
37 using cv::ml::ANN_MLP;
38 using cv::ml::DTrees;
39 using cv::ml::Boost;
40 using cv::ml::RTrees;
41 
42 class CV_MLBaseTest : public cvtest::BaseTest
43 {
44 public:
45     CV_MLBaseTest( const char* _modelName );
46     virtual ~CV_MLBaseTest();
47 protected:
48     virtual int read_params( CvFileStorage* fs );
49     virtual void run( int startFrom );
50     virtual int prepare_test_case( int testCaseIdx );
51     virtual std::string& get_validation_filename();
52     virtual int run_test_case( int testCaseIdx ) = 0;
53     virtual int validate_test_results( int testCaseIdx ) = 0;
54 
55     int train( int testCaseIdx );
56     float get_test_error( int testCaseIdx, std::vector<float> *resp = 0 );
57     void save( const char* filename );
58     void load( const char* filename );
59 
60     Ptr<TrainData> data;
61     std::string modelName, validationFN;
62     std::vector<std::string> dataSetNames;
63     cv::FileStorage validationFS;
64 
65     Ptr<StatModel> model;
66 
67     std::map<int, int> cls_map;
68 
69     int64 initSeed;
70 };
71 
72 class CV_AMLTest : public CV_MLBaseTest
73 {
74 public:
75     CV_AMLTest( const char* _modelName );
~CV_AMLTest()76     virtual ~CV_AMLTest() {}
77 protected:
78     virtual int run_test_case( int testCaseIdx );
79     virtual int validate_test_results( int testCaseIdx );
80 };
81 
82 class CV_SLMLTest : public CV_MLBaseTest
83 {
84 public:
85     CV_SLMLTest( const char* _modelName );
~CV_SLMLTest()86     virtual ~CV_SLMLTest() {}
87 protected:
88     virtual int run_test_case( int testCaseIdx );
89     virtual int validate_test_results( int testCaseIdx );
90 
91     std::vector<float> test_resps1, test_resps2; // predicted responses for test data
92     std::string fname1, fname2;
93 };
94 
95 #endif
96