1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
21 //
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
25 //
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "test_precomp.hpp"
43
44 using namespace cv;
45 using namespace std;
46
str_to_svm_type(String & str)47 int str_to_svm_type(String& str)
48 {
49 if( !str.compare("C_SVC") )
50 return SVM::C_SVC;
51 if( !str.compare("NU_SVC") )
52 return SVM::NU_SVC;
53 if( !str.compare("ONE_CLASS") )
54 return SVM::ONE_CLASS;
55 if( !str.compare("EPS_SVR") )
56 return SVM::EPS_SVR;
57 if( !str.compare("NU_SVR") )
58 return SVM::NU_SVR;
59 CV_Error( CV_StsBadArg, "incorrect svm type string" );
60 return -1;
61 }
str_to_svm_kernel_type(String & str)62 int str_to_svm_kernel_type( String& str )
63 {
64 if( !str.compare("LINEAR") )
65 return SVM::LINEAR;
66 if( !str.compare("POLY") )
67 return SVM::POLY;
68 if( !str.compare("RBF") )
69 return SVM::RBF;
70 if( !str.compare("SIGMOID") )
71 return SVM::SIGMOID;
72 CV_Error( CV_StsBadArg, "incorrect svm type string" );
73 return -1;
74 }
75
76 // 4. em
77 // 5. ann
str_to_ann_train_method(String & str)78 int str_to_ann_train_method( String& str )
79 {
80 if( !str.compare("BACKPROP") )
81 return ANN_MLP::BACKPROP;
82 if( !str.compare("RPROP") )
83 return ANN_MLP::RPROP;
84 CV_Error( CV_StsBadArg, "incorrect ann train method string" );
85 return -1;
86 }
87
ann_check_data(Ptr<TrainData> _data)88 void ann_check_data( Ptr<TrainData> _data )
89 {
90 Mat values = _data->getSamples();
91 Mat var_idx = _data->getVarIdx();
92 int nvars = (int)var_idx.total();
93 if( nvars != 0 && nvars != values.cols )
94 CV_Error( CV_StsBadArg, "var_idx is not supported" );
95 if( !_data->getMissing().empty() )
96 CV_Error( CV_StsBadArg, "missing values are not supported" );
97 }
98
99 // unroll the categorical responses to binary vectors
ann_get_new_responses(Ptr<TrainData> _data,map<int,int> & cls_map)100 Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map )
101 {
102 Mat train_sidx = _data->getTrainSampleIdx();
103 int* train_sidx_ptr = train_sidx.ptr<int>();
104 Mat responses = _data->getResponses();
105 int cls_count = 0;
106 // construct cls_map
107 cls_map.clear();
108 int nresponses = (int)responses.total();
109 int si, n = !train_sidx.empty() ? (int)train_sidx.total() : nresponses;
110
111 for( si = 0; si < n; si++ )
112 {
113 int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
114 int r = cvRound(responses.at<float>(sidx));
115 CV_DbgAssert( fabs(responses.at<float>(sidx) - r) < FLT_EPSILON );
116 map<int,int>::iterator it = cls_map.find(r);
117 if( it == cls_map.end() )
118 cls_map[r] = cls_count++;
119 }
120 Mat new_responses = Mat::zeros( nresponses, cls_count, CV_32F );
121 for( si = 0; si < n; si++ )
122 {
123 int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
124 int r = cvRound(responses.at<float>(sidx));
125 int cidx = cls_map[r];
126 new_responses.at<float>(sidx, cidx) = 1.f;
127 }
128 return new_responses;
129 }
130
ann_calc_error(Ptr<StatModel> ann,Ptr<TrainData> _data,map<int,int> & cls_map,int type,vector<float> * resp_labels)131 float ann_calc_error( Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels )
132 {
133 float err = 0;
134 Mat samples = _data->getSamples();
135 Mat responses = _data->getResponses();
136 Mat sample_idx = (type == CV_TEST_ERROR) ? _data->getTestSampleIdx() : _data->getTrainSampleIdx();
137 int* sidx = !sample_idx.empty() ? sample_idx.ptr<int>() : 0;
138 ann_check_data( _data );
139 int sample_count = (int)sample_idx.total();
140 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? samples.rows : sample_count;
141 float* pred_resp = 0;
142 vector<float> innresp;
143 if( sample_count > 0 )
144 {
145 if( resp_labels )
146 {
147 resp_labels->resize( sample_count );
148 pred_resp = &((*resp_labels)[0]);
149 }
150 else
151 {
152 innresp.resize( sample_count );
153 pred_resp = &(innresp[0]);
154 }
155 }
156 int cls_count = (int)cls_map.size();
157 Mat output( 1, cls_count, CV_32FC1 );
158
159 for( int i = 0; i < sample_count; i++ )
160 {
161 int si = sidx ? sidx[i] : i;
162 Mat sample = samples.row(si);
163 ann->predict( sample, output );
164 Point best_cls;
165 minMaxLoc(output, 0, 0, 0, &best_cls, 0);
166 int r = cvRound(responses.at<float>(si));
167 CV_DbgAssert( fabs(responses.at<float>(si) - r) < FLT_EPSILON );
168 r = cls_map[r];
169 int d = best_cls.x == r ? 0 : 1;
170 err += d;
171 pred_resp[i] = (float)best_cls.x;
172 }
173 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
174 return err;
175 }
176
177 // 6. dtree
178 // 7. boost
str_to_boost_type(String & str)179 int str_to_boost_type( String& str )
180 {
181 if ( !str.compare("DISCRETE") )
182 return Boost::DISCRETE;
183 if ( !str.compare("REAL") )
184 return Boost::REAL;
185 if ( !str.compare("LOGIT") )
186 return Boost::LOGIT;
187 if ( !str.compare("GENTLE") )
188 return Boost::GENTLE;
189 CV_Error( CV_StsBadArg, "incorrect boost type string" );
190 return -1;
191 }
192
193 // 8. rtrees
194 // 9. ertrees
195
196 // ---------------------------------- MLBaseTest ---------------------------------------------------
197
CV_MLBaseTest(const char * _modelName)198 CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
199 {
200 int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
201 CV_BIG_INT(0x0000a17166072c7c),
202 CV_BIG_INT(0x0201b32115cd1f9a),
203 CV_BIG_INT(0x0513cb37abcd1234),
204 CV_BIG_INT(0x0001a2b3c4d5f678)
205 };
206
207 int seedCount = sizeof(seeds)/sizeof(seeds[0]);
208 RNG& rng = theRNG();
209
210 initSeed = rng.state;
211 rng.state = seeds[rng(seedCount)];
212
213 modelName = _modelName;
214 }
215
~CV_MLBaseTest()216 CV_MLBaseTest::~CV_MLBaseTest()
217 {
218 if( validationFS.isOpened() )
219 validationFS.release();
220 theRNG().state = initSeed;
221 }
222
read_params(CvFileStorage * __fs)223 int CV_MLBaseTest::read_params( CvFileStorage* __fs )
224 {
225 FileStorage _fs(__fs, false);
226 if( !_fs.isOpened() )
227 test_case_count = -1;
228 else
229 {
230 FileNode fn = _fs.getFirstTopLevelNode()["run_params"][modelName];
231 test_case_count = (int)fn.size();
232 if( test_case_count <= 0 )
233 test_case_count = -1;
234 if( test_case_count > 0 )
235 {
236 dataSetNames.resize( test_case_count );
237 FileNodeIterator it = fn.begin();
238 for( int i = 0; i < test_case_count; i++, ++it )
239 {
240 dataSetNames[i] = (string)*it;
241 }
242 }
243 }
244 return cvtest::TS::OK;;
245 }
246
run(int)247 void CV_MLBaseTest::run( int )
248 {
249 string filename = ts->get_data_path();
250 filename += get_validation_filename();
251 validationFS.open( filename, FileStorage::READ );
252 read_params( *validationFS );
253
254 int code = cvtest::TS::OK;
255 for (int i = 0; i < test_case_count; i++)
256 {
257 int temp_code = run_test_case( i );
258 if (temp_code == cvtest::TS::OK)
259 temp_code = validate_test_results( i );
260 if (temp_code != cvtest::TS::OK)
261 code = temp_code;
262 }
263 if ( test_case_count <= 0)
264 {
265 ts->printf( cvtest::TS::LOG, "validation file is not determined or not correct" );
266 code = cvtest::TS::FAIL_INVALID_TEST_DATA;
267 }
268 ts->set_failed_test_info( code );
269 }
270
prepare_test_case(int test_case_idx)271 int CV_MLBaseTest::prepare_test_case( int test_case_idx )
272 {
273 clear();
274
275 string dataPath = ts->get_data_path();
276 if ( dataPath.empty() )
277 {
278 ts->printf( cvtest::TS::LOG, "data path is empty" );
279 return cvtest::TS::FAIL_INVALID_TEST_DATA;
280 }
281
282 string dataName = dataSetNames[test_case_idx],
283 filename = dataPath + dataName + ".data";
284
285 FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
286 CV_DbgAssert( !dataParamsNode.empty() );
287
288 CV_DbgAssert( !dataParamsNode["LS"].empty() );
289 int trainSampleCount = (int)dataParamsNode["LS"];
290
291 CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
292 int respIdx = (int)dataParamsNode["resp_idx"];
293
294 CV_DbgAssert( !dataParamsNode["types"].empty() );
295 String varTypes = (String)dataParamsNode["types"];
296
297 data = TrainData::loadFromCSV(filename, 0, respIdx, respIdx+1, varTypes);
298 if( data.empty() )
299 {
300 ts->printf( cvtest::TS::LOG, "file %s can not be read\n", filename.c_str() );
301 return cvtest::TS::FAIL_INVALID_TEST_DATA;
302 }
303
304 data->setTrainTestSplit(trainSampleCount);
305 return cvtest::TS::OK;
306 }
307
get_validation_filename()308 string& CV_MLBaseTest::get_validation_filename()
309 {
310 return validationFN;
311 }
312
train(int testCaseIdx)313 int CV_MLBaseTest::train( int testCaseIdx )
314 {
315 bool is_trained = false;
316 FileNode modelParamsNode =
317 validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
318
319 if( modelName == CV_NBAYES )
320 model = NormalBayesClassifier::create();
321 else if( modelName == CV_KNEAREST )
322 {
323 model = KNearest::create();
324 }
325 else if( modelName == CV_SVM )
326 {
327 String svm_type_str, kernel_type_str;
328 modelParamsNode["svm_type"] >> svm_type_str;
329 modelParamsNode["kernel_type"] >> kernel_type_str;
330 Ptr<SVM> m = SVM::create();
331 m->setType(str_to_svm_type( svm_type_str ));
332 m->setKernel(str_to_svm_kernel_type( kernel_type_str ));
333 m->setDegree(modelParamsNode["degree"]);
334 m->setGamma(modelParamsNode["gamma"]);
335 m->setCoef0(modelParamsNode["coef0"]);
336 m->setC(modelParamsNode["C"]);
337 m->setNu(modelParamsNode["nu"]);
338 m->setP(modelParamsNode["p"]);
339 model = m;
340 }
341 else if( modelName == CV_EM )
342 {
343 assert( 0 );
344 }
345 else if( modelName == CV_ANN )
346 {
347 String train_method_str;
348 double param1, param2;
349 modelParamsNode["train_method"] >> train_method_str;
350 modelParamsNode["param1"] >> param1;
351 modelParamsNode["param2"] >> param2;
352 Mat new_responses = ann_get_new_responses( data, cls_map );
353 // binarize the responses
354 data = TrainData::create(data->getSamples(), data->getLayout(), new_responses,
355 data->getVarIdx(), data->getTrainSampleIdx());
356 int layer_sz[] = { data->getNAllVars(), 100, 100, (int)cls_map.size() };
357 Mat layer_sizes( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
358 Ptr<ANN_MLP> m = ANN_MLP::create();
359 m->setLayerSizes(layer_sizes);
360 m->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
361 m->setTermCriteria(TermCriteria(TermCriteria::COUNT,300,0.01));
362 m->setTrainMethod(str_to_ann_train_method(train_method_str), param1, param2);
363 model = m;
364
365 }
366 else if( modelName == CV_DTREE )
367 {
368 int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
369 float REG_ACCURACY = 0;
370 bool USE_SURROGATE = false, IS_PRUNED;
371 modelParamsNode["max_depth"] >> MAX_DEPTH;
372 modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
373 //modelParamsNode["use_surrogate"] >> USE_SURROGATE;
374 modelParamsNode["max_categories"] >> MAX_CATEGORIES;
375 modelParamsNode["cv_folds"] >> CV_FOLDS;
376 modelParamsNode["is_pruned"] >> IS_PRUNED;
377
378 Ptr<DTrees> m = DTrees::create();
379 m->setMaxDepth(MAX_DEPTH);
380 m->setMinSampleCount(MIN_SAMPLE_COUNT);
381 m->setRegressionAccuracy(REG_ACCURACY);
382 m->setUseSurrogates(USE_SURROGATE);
383 m->setMaxCategories(MAX_CATEGORIES);
384 m->setCVFolds(CV_FOLDS);
385 m->setUse1SERule(false);
386 m->setTruncatePrunedTree(IS_PRUNED);
387 m->setPriors(Mat());
388 model = m;
389 }
390 else if( modelName == CV_BOOST )
391 {
392 int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;
393 float WEIGHT_TRIM_RATE;
394 bool USE_SURROGATE = false;
395 String typeStr;
396 modelParamsNode["type"] >> typeStr;
397 BOOST_TYPE = str_to_boost_type( typeStr );
398 modelParamsNode["weak_count"] >> WEAK_COUNT;
399 modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;
400 modelParamsNode["max_depth"] >> MAX_DEPTH;
401 //modelParamsNode["use_surrogate"] >> USE_SURROGATE;
402
403 Ptr<Boost> m = Boost::create();
404 m->setBoostType(BOOST_TYPE);
405 m->setWeakCount(WEAK_COUNT);
406 m->setWeightTrimRate(WEIGHT_TRIM_RATE);
407 m->setMaxDepth(MAX_DEPTH);
408 m->setUseSurrogates(USE_SURROGATE);
409 m->setPriors(Mat());
410 model = m;
411 }
412 else if( modelName == CV_RTREES )
413 {
414 int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
415 float REG_ACCURACY = 0, OOB_EPS = 0.0;
416 bool USE_SURROGATE = false, IS_PRUNED;
417 modelParamsNode["max_depth"] >> MAX_DEPTH;
418 modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
419 //modelParamsNode["use_surrogate"] >> USE_SURROGATE;
420 modelParamsNode["max_categories"] >> MAX_CATEGORIES;
421 modelParamsNode["cv_folds"] >> CV_FOLDS;
422 modelParamsNode["is_pruned"] >> IS_PRUNED;
423 modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
424 modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
425
426 Ptr<RTrees> m = RTrees::create();
427 m->setMaxDepth(MAX_DEPTH);
428 m->setMinSampleCount(MIN_SAMPLE_COUNT);
429 m->setRegressionAccuracy(REG_ACCURACY);
430 m->setUseSurrogates(USE_SURROGATE);
431 m->setMaxCategories(MAX_CATEGORIES);
432 m->setPriors(Mat());
433 m->setCalculateVarImportance(true);
434 m->setActiveVarCount(NACTIVE_VARS);
435 m->setTermCriteria(TermCriteria(TermCriteria::COUNT, MAX_TREES_NUM, OOB_EPS));
436 model = m;
437 }
438
439 if( !model.empty() )
440 is_trained = model->train(data, 0);
441
442 if( !is_trained )
443 {
444 ts->printf( cvtest::TS::LOG, "in test case %d model training was failed", testCaseIdx );
445 return cvtest::TS::FAIL_INVALID_OUTPUT;
446 }
447 return cvtest::TS::OK;
448 }
449
get_test_error(int,vector<float> * resp)450 float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
451 {
452 int type = CV_TEST_ERROR;
453 float err = 0;
454 Mat _resp;
455 if( modelName == CV_EM )
456 assert( 0 );
457 else if( modelName == CV_ANN )
458 err = ann_calc_error( model, data, cls_map, type, resp );
459 else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
460 modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST )
461 err = model->calcError( data, true, _resp );
462 if( !_resp.empty() && resp )
463 _resp.convertTo(*resp, CV_32F);
464 return err;
465 }
466
save(const char * filename)467 void CV_MLBaseTest::save( const char* filename )
468 {
469 model->save( filename );
470 }
471
load(const char * filename)472 void CV_MLBaseTest::load( const char* filename )
473 {
474 if( modelName == CV_NBAYES )
475 model = Algorithm::load<NormalBayesClassifier>( filename );
476 else if( modelName == CV_KNEAREST )
477 model = Algorithm::load<KNearest>( filename );
478 else if( modelName == CV_SVM )
479 model = Algorithm::load<SVM>( filename );
480 else if( modelName == CV_ANN )
481 model = Algorithm::load<ANN_MLP>( filename );
482 else if( modelName == CV_DTREE )
483 model = Algorithm::load<DTrees>( filename );
484 else if( modelName == CV_BOOST )
485 model = Algorithm::load<Boost>( filename );
486 else if( modelName == CV_RTREES )
487 model = Algorithm::load<RTrees>( filename );
488 else
489 CV_Error( CV_StsNotImplemented, "invalid stat model name");
490 }
491
492 /* End of file. */
493