1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41 
42 #include "test_precomp.hpp"
43 
44 using namespace cv;
45 using namespace std;
46 
str_to_svm_type(String & str)47 int str_to_svm_type(String& str)
48 {
49     if( !str.compare("C_SVC") )
50         return SVM::C_SVC;
51     if( !str.compare("NU_SVC") )
52         return SVM::NU_SVC;
53     if( !str.compare("ONE_CLASS") )
54         return SVM::ONE_CLASS;
55     if( !str.compare("EPS_SVR") )
56         return SVM::EPS_SVR;
57     if( !str.compare("NU_SVR") )
58         return SVM::NU_SVR;
59     CV_Error( CV_StsBadArg, "incorrect svm type string" );
60     return -1;
61 }
str_to_svm_kernel_type(String & str)62 int str_to_svm_kernel_type( String& str )
63 {
64     if( !str.compare("LINEAR") )
65         return SVM::LINEAR;
66     if( !str.compare("POLY") )
67         return SVM::POLY;
68     if( !str.compare("RBF") )
69         return SVM::RBF;
70     if( !str.compare("SIGMOID") )
71         return SVM::SIGMOID;
72     CV_Error( CV_StsBadArg, "incorrect svm type string" );
73     return -1;
74 }
75 
76 // 4. em
77 // 5. ann
str_to_ann_train_method(String & str)78 int str_to_ann_train_method( String& str )
79 {
80     if( !str.compare("BACKPROP") )
81         return ANN_MLP::BACKPROP;
82     if( !str.compare("RPROP") )
83         return ANN_MLP::RPROP;
84     CV_Error( CV_StsBadArg, "incorrect ann train method string" );
85     return -1;
86 }
87 
ann_check_data(Ptr<TrainData> _data)88 void ann_check_data( Ptr<TrainData> _data )
89 {
90     Mat values = _data->getSamples();
91     Mat var_idx = _data->getVarIdx();
92     int nvars = (int)var_idx.total();
93     if( nvars != 0 && nvars != values.cols )
94         CV_Error( CV_StsBadArg, "var_idx is not supported" );
95     if( !_data->getMissing().empty() )
96         CV_Error( CV_StsBadArg, "missing values are not supported" );
97 }
98 
99 // unroll the categorical responses to binary vectors
ann_get_new_responses(Ptr<TrainData> _data,map<int,int> & cls_map)100 Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map )
101 {
102     Mat train_sidx = _data->getTrainSampleIdx();
103     int* train_sidx_ptr = train_sidx.ptr<int>();
104     Mat responses = _data->getResponses();
105     int cls_count = 0;
106     // construct cls_map
107     cls_map.clear();
108     int nresponses = (int)responses.total();
109     int si, n = !train_sidx.empty() ? (int)train_sidx.total() : nresponses;
110 
111     for( si = 0; si < n; si++ )
112     {
113         int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
114         int r = cvRound(responses.at<float>(sidx));
115         CV_DbgAssert( fabs(responses.at<float>(sidx) - r) < FLT_EPSILON );
116         map<int,int>::iterator it = cls_map.find(r);
117         if( it == cls_map.end() )
118             cls_map[r] = cls_count++;
119     }
120     Mat new_responses = Mat::zeros( nresponses, cls_count, CV_32F );
121     for( si = 0; si < n; si++ )
122     {
123         int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
124         int r = cvRound(responses.at<float>(sidx));
125         int cidx = cls_map[r];
126         new_responses.at<float>(sidx, cidx) = 1.f;
127     }
128     return new_responses;
129 }
130 
ann_calc_error(Ptr<StatModel> ann,Ptr<TrainData> _data,map<int,int> & cls_map,int type,vector<float> * resp_labels)131 float ann_calc_error( Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels )
132 {
133     float err = 0;
134     Mat samples = _data->getSamples();
135     Mat responses = _data->getResponses();
136     Mat sample_idx = (type == CV_TEST_ERROR) ? _data->getTestSampleIdx() : _data->getTrainSampleIdx();
137     int* sidx = !sample_idx.empty() ? sample_idx.ptr<int>() : 0;
138     ann_check_data( _data );
139     int sample_count = (int)sample_idx.total();
140     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? samples.rows : sample_count;
141     float* pred_resp = 0;
142     vector<float> innresp;
143     if( sample_count > 0 )
144     {
145         if( resp_labels )
146         {
147             resp_labels->resize( sample_count );
148             pred_resp = &((*resp_labels)[0]);
149         }
150         else
151         {
152             innresp.resize( sample_count );
153             pred_resp = &(innresp[0]);
154         }
155     }
156     int cls_count = (int)cls_map.size();
157     Mat output( 1, cls_count, CV_32FC1 );
158 
159     for( int i = 0; i < sample_count; i++ )
160     {
161         int si = sidx ? sidx[i] : i;
162         Mat sample = samples.row(si);
163         ann->predict( sample, output );
164         Point best_cls;
165         minMaxLoc(output, 0, 0, 0, &best_cls, 0);
166         int r = cvRound(responses.at<float>(si));
167         CV_DbgAssert( fabs(responses.at<float>(si) - r) < FLT_EPSILON );
168         r = cls_map[r];
169         int d = best_cls.x == r ? 0 : 1;
170         err += d;
171         pred_resp[i] = (float)best_cls.x;
172     }
173     err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
174     return err;
175 }
176 
177 // 6. dtree
178 // 7. boost
str_to_boost_type(String & str)179 int str_to_boost_type( String& str )
180 {
181     if ( !str.compare("DISCRETE") )
182         return Boost::DISCRETE;
183     if ( !str.compare("REAL") )
184         return Boost::REAL;
185     if ( !str.compare("LOGIT") )
186         return Boost::LOGIT;
187     if ( !str.compare("GENTLE") )
188         return Boost::GENTLE;
189     CV_Error( CV_StsBadArg, "incorrect boost type string" );
190     return -1;
191 }
192 
193 // 8. rtrees
194 // 9. ertrees
195 
196 // ---------------------------------- MLBaseTest ---------------------------------------------------
197 
CV_MLBaseTest(const char * _modelName)198 CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
199 {
200     int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
201                       CV_BIG_INT(0x0000a17166072c7c),
202                       CV_BIG_INT(0x0201b32115cd1f9a),
203                       CV_BIG_INT(0x0513cb37abcd1234),
204                       CV_BIG_INT(0x0001a2b3c4d5f678)
205                     };
206 
207     int seedCount = sizeof(seeds)/sizeof(seeds[0]);
208     RNG& rng = theRNG();
209 
210     initSeed = rng.state;
211     rng.state = seeds[rng(seedCount)];
212 
213     modelName = _modelName;
214 }
215 
~CV_MLBaseTest()216 CV_MLBaseTest::~CV_MLBaseTest()
217 {
218     if( validationFS.isOpened() )
219         validationFS.release();
220     theRNG().state = initSeed;
221 }
222 
read_params(CvFileStorage * __fs)223 int CV_MLBaseTest::read_params( CvFileStorage* __fs )
224 {
225     FileStorage _fs(__fs, false);
226     if( !_fs.isOpened() )
227         test_case_count = -1;
228     else
229     {
230         FileNode fn = _fs.getFirstTopLevelNode()["run_params"][modelName];
231         test_case_count = (int)fn.size();
232         if( test_case_count <= 0 )
233             test_case_count = -1;
234         if( test_case_count > 0 )
235         {
236             dataSetNames.resize( test_case_count );
237             FileNodeIterator it = fn.begin();
238             for( int i = 0; i < test_case_count; i++, ++it )
239             {
240                 dataSetNames[i] = (string)*it;
241             }
242         }
243     }
244     return cvtest::TS::OK;;
245 }
246 
run(int)247 void CV_MLBaseTest::run( int )
248 {
249     string filename = ts->get_data_path();
250     filename += get_validation_filename();
251     validationFS.open( filename, FileStorage::READ );
252     read_params( *validationFS );
253 
254     int code = cvtest::TS::OK;
255     for (int i = 0; i < test_case_count; i++)
256     {
257         int temp_code = run_test_case( i );
258         if (temp_code == cvtest::TS::OK)
259             temp_code = validate_test_results( i );
260         if (temp_code != cvtest::TS::OK)
261             code = temp_code;
262     }
263     if ( test_case_count <= 0)
264     {
265         ts->printf( cvtest::TS::LOG, "validation file is not determined or not correct" );
266         code = cvtest::TS::FAIL_INVALID_TEST_DATA;
267     }
268     ts->set_failed_test_info( code );
269 }
270 
prepare_test_case(int test_case_idx)271 int CV_MLBaseTest::prepare_test_case( int test_case_idx )
272 {
273     clear();
274 
275     string dataPath = ts->get_data_path();
276     if ( dataPath.empty() )
277     {
278         ts->printf( cvtest::TS::LOG, "data path is empty" );
279         return cvtest::TS::FAIL_INVALID_TEST_DATA;
280     }
281 
282     string dataName = dataSetNames[test_case_idx],
283         filename = dataPath + dataName + ".data";
284 
285     FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
286     CV_DbgAssert( !dataParamsNode.empty() );
287 
288     CV_DbgAssert( !dataParamsNode["LS"].empty() );
289     int trainSampleCount = (int)dataParamsNode["LS"];
290 
291     CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
292     int respIdx = (int)dataParamsNode["resp_idx"];
293 
294     CV_DbgAssert( !dataParamsNode["types"].empty() );
295     String varTypes = (String)dataParamsNode["types"];
296 
297     data = TrainData::loadFromCSV(filename, 0, respIdx, respIdx+1, varTypes);
298     if( data.empty() )
299     {
300         ts->printf( cvtest::TS::LOG, "file %s can not be read\n", filename.c_str() );
301         return cvtest::TS::FAIL_INVALID_TEST_DATA;
302     }
303 
304     data->setTrainTestSplit(trainSampleCount);
305     return cvtest::TS::OK;
306 }
307 
get_validation_filename()308 string& CV_MLBaseTest::get_validation_filename()
309 {
310     return validationFN;
311 }
312 
train(int testCaseIdx)313 int CV_MLBaseTest::train( int testCaseIdx )
314 {
315     bool is_trained = false;
316     FileNode modelParamsNode =
317         validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
318 
319     if( modelName == CV_NBAYES )
320         model = NormalBayesClassifier::create();
321     else if( modelName == CV_KNEAREST )
322     {
323         model = KNearest::create();
324     }
325     else if( modelName == CV_SVM )
326     {
327         String svm_type_str, kernel_type_str;
328         modelParamsNode["svm_type"] >> svm_type_str;
329         modelParamsNode["kernel_type"] >> kernel_type_str;
330         Ptr<SVM> m = SVM::create();
331         m->setType(str_to_svm_type( svm_type_str ));
332         m->setKernel(str_to_svm_kernel_type( kernel_type_str ));
333         m->setDegree(modelParamsNode["degree"]);
334         m->setGamma(modelParamsNode["gamma"]);
335         m->setCoef0(modelParamsNode["coef0"]);
336         m->setC(modelParamsNode["C"]);
337         m->setNu(modelParamsNode["nu"]);
338         m->setP(modelParamsNode["p"]);
339         model = m;
340     }
341     else if( modelName == CV_EM )
342     {
343         assert( 0 );
344     }
345     else if( modelName == CV_ANN )
346     {
347         String train_method_str;
348         double param1, param2;
349         modelParamsNode["train_method"] >> train_method_str;
350         modelParamsNode["param1"] >> param1;
351         modelParamsNode["param2"] >> param2;
352         Mat new_responses = ann_get_new_responses( data, cls_map );
353         // binarize the responses
354         data = TrainData::create(data->getSamples(), data->getLayout(), new_responses,
355                                  data->getVarIdx(), data->getTrainSampleIdx());
356         int layer_sz[] = { data->getNAllVars(), 100, 100, (int)cls_map.size() };
357         Mat layer_sizes( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
358         Ptr<ANN_MLP> m = ANN_MLP::create();
359         m->setLayerSizes(layer_sizes);
360         m->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
361         m->setTermCriteria(TermCriteria(TermCriteria::COUNT,300,0.01));
362         m->setTrainMethod(str_to_ann_train_method(train_method_str), param1, param2);
363         model = m;
364 
365     }
366     else if( modelName == CV_DTREE )
367     {
368         int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
369         float REG_ACCURACY = 0;
370         bool USE_SURROGATE = false, IS_PRUNED;
371         modelParamsNode["max_depth"] >> MAX_DEPTH;
372         modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
373         //modelParamsNode["use_surrogate"] >> USE_SURROGATE;
374         modelParamsNode["max_categories"] >> MAX_CATEGORIES;
375         modelParamsNode["cv_folds"] >> CV_FOLDS;
376         modelParamsNode["is_pruned"] >> IS_PRUNED;
377 
378         Ptr<DTrees> m = DTrees::create();
379         m->setMaxDepth(MAX_DEPTH);
380         m->setMinSampleCount(MIN_SAMPLE_COUNT);
381         m->setRegressionAccuracy(REG_ACCURACY);
382         m->setUseSurrogates(USE_SURROGATE);
383         m->setMaxCategories(MAX_CATEGORIES);
384         m->setCVFolds(CV_FOLDS);
385         m->setUse1SERule(false);
386         m->setTruncatePrunedTree(IS_PRUNED);
387         m->setPriors(Mat());
388         model = m;
389     }
390     else if( modelName == CV_BOOST )
391     {
392         int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;
393         float WEIGHT_TRIM_RATE;
394         bool USE_SURROGATE = false;
395         String typeStr;
396         modelParamsNode["type"] >> typeStr;
397         BOOST_TYPE = str_to_boost_type( typeStr );
398         modelParamsNode["weak_count"] >> WEAK_COUNT;
399         modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;
400         modelParamsNode["max_depth"] >> MAX_DEPTH;
401         //modelParamsNode["use_surrogate"] >> USE_SURROGATE;
402 
403         Ptr<Boost> m = Boost::create();
404         m->setBoostType(BOOST_TYPE);
405         m->setWeakCount(WEAK_COUNT);
406         m->setWeightTrimRate(WEIGHT_TRIM_RATE);
407         m->setMaxDepth(MAX_DEPTH);
408         m->setUseSurrogates(USE_SURROGATE);
409         m->setPriors(Mat());
410         model = m;
411     }
412     else if( modelName == CV_RTREES )
413     {
414         int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
415         float REG_ACCURACY = 0, OOB_EPS = 0.0;
416         bool USE_SURROGATE = false, IS_PRUNED;
417         modelParamsNode["max_depth"] >> MAX_DEPTH;
418         modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
419         //modelParamsNode["use_surrogate"] >> USE_SURROGATE;
420         modelParamsNode["max_categories"] >> MAX_CATEGORIES;
421         modelParamsNode["cv_folds"] >> CV_FOLDS;
422         modelParamsNode["is_pruned"] >> IS_PRUNED;
423         modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
424         modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
425 
426         Ptr<RTrees> m = RTrees::create();
427         m->setMaxDepth(MAX_DEPTH);
428         m->setMinSampleCount(MIN_SAMPLE_COUNT);
429         m->setRegressionAccuracy(REG_ACCURACY);
430         m->setUseSurrogates(USE_SURROGATE);
431         m->setMaxCategories(MAX_CATEGORIES);
432         m->setPriors(Mat());
433         m->setCalculateVarImportance(true);
434         m->setActiveVarCount(NACTIVE_VARS);
435         m->setTermCriteria(TermCriteria(TermCriteria::COUNT, MAX_TREES_NUM, OOB_EPS));
436         model = m;
437     }
438 
439     if( !model.empty() )
440         is_trained = model->train(data, 0);
441 
442     if( !is_trained )
443     {
444         ts->printf( cvtest::TS::LOG, "in test case %d model training was failed", testCaseIdx );
445         return cvtest::TS::FAIL_INVALID_OUTPUT;
446     }
447     return cvtest::TS::OK;
448 }
449 
get_test_error(int,vector<float> * resp)450 float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
451 {
452     int type = CV_TEST_ERROR;
453     float err = 0;
454     Mat _resp;
455     if( modelName == CV_EM )
456         assert( 0 );
457     else if( modelName == CV_ANN )
458         err = ann_calc_error( model, data, cls_map, type, resp );
459     else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
460              modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST )
461         err = model->calcError( data, true, _resp );
462     if( !_resp.empty() && resp )
463         _resp.convertTo(*resp, CV_32F);
464     return err;
465 }
466 
save(const char * filename)467 void CV_MLBaseTest::save( const char* filename )
468 {
469     model->save( filename );
470 }
471 
load(const char * filename)472 void CV_MLBaseTest::load( const char* filename )
473 {
474     if( modelName == CV_NBAYES )
475         model = Algorithm::load<NormalBayesClassifier>( filename );
476     else if( modelName == CV_KNEAREST )
477         model = Algorithm::load<KNearest>( filename );
478     else if( modelName == CV_SVM )
479         model = Algorithm::load<SVM>( filename );
480     else if( modelName == CV_ANN )
481         model = Algorithm::load<ANN_MLP>( filename );
482     else if( modelName == CV_DTREE )
483         model = Algorithm::load<DTrees>( filename );
484     else if( modelName == CV_BOOST )
485         model = Algorithm::load<Boost>( filename );
486     else if( modelName == CV_RTREES )
487         model = Algorithm::load<RTrees>( filename );
488     else
489         CV_Error( CV_StsNotImplemented, "invalid stat model name");
490 }
491 
492 /* End of file. */
493