/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

#include <stdarg.h>
#include <ctype.h>
/****************************************************************************************\
                                COPYRIGHT NOTICE
                                ----------------

  The code has been derived from libsvm library (version 2.6)
  (http://www.csie.ntu.edu.tw/~cjlin/libsvm).

  Here is the original copyright:
------------------------------------------------------------------------------------------
    Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
    are met:

    1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.

    3. Neither name of copyright holders nor the names of its contributors
    may be used to endorse or promote products derived from this software
    without specific prior written permission.


    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\****************************************************************************************/

namespace cv { namespace ml {

typedef float Qfloat;
const int QFLOAT_TYPE = DataDepth<Qfloat>::value;
// Param Grid
static void checkParamGrid(const ParamGrid& pg)
{
    if( pg.minVal > pg.maxVal )
        CV_Error( CV_StsBadArg, "Lower bound of the grid must be less than the upper one" );
    if( pg.minVal < DBL_EPSILON )
        CV_Error( CV_StsBadArg, "Lower bound of the grid must be positive" );
    if( pg.logStep < 1. + FLT_EPSILON )
        CV_Error( CV_StsBadArg, "Grid step must be greater than 1" );
}
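// Note: such a grid is assumed to be swept multiplicatively (as trainAuto does):
// the candidate values are minVal, minVal*logStep, minVal*logStep^2, ...,
// up to maxVal, which is why logStep must be strictly greater than 1.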

// SVM training parameters
struct SvmParams
{
    int         svmType;
    int         kernelType;
    double      gamma;
    double      coef0;
    double      degree;
    double      C;
    double      nu;
    double      p;
    Mat         classWeights;
    TermCriteria termCrit;

    SvmParams()
    {
        svmType = SVM::C_SVC;
        kernelType = SVM::RBF;
        degree = 0;
        gamma = 1;
        coef0 = 0;
        C = 1;
        nu = 0;
        p = 0;
        termCrit = TermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
    }

    SvmParams( int _svmType, int _kernelType,
            double _degree, double _gamma, double _coef0,
            double _Con, double _nu, double _p,
            const Mat& _classWeights, TermCriteria _termCrit )
    {
        svmType = _svmType;
        kernelType = _kernelType;
        degree = _degree;
        gamma = _gamma;
        coef0 = _coef0;
        C = _Con;
        nu = _nu;
        p = _p;
        classWeights = _classWeights;
        termCrit = _termCrit;
    }

};

/////////////////////////////////////// SVM kernel ///////////////////////////////////////
class SVMKernelImpl : public SVM::Kernel
{
public:
    SVMKernelImpl( const SvmParams& _params = SvmParams() )
    {
        params = _params;
    }

    int getType() const
    {
        return params.kernelType;
    }

    void calc_non_rbf_base( int vcount, int var_count, const float* vecs,
                            const float* another, Qfloat* results,
                            double alpha, double beta )
    {
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
                s += sample[k]*another[k] + sample[k+1]*another[k+1] +
                sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
            for( ; k < var_count; k++ )
                s += sample[k]*another[k];
            results[j] = (Qfloat)(s*alpha + beta);
        }
    }

    void calc_linear( int vcount, int var_count, const float* vecs,
                      const float* another, Qfloat* results )
    {
        calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
    }

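    // Polynomial kernel: K(x, y) = (gamma*<x, y> + coef0)^degree.
    // calc_non_rbf_base computes the affine part; the whole row of results
    // is then raised to the power in a single vectorized pow() call.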
    void calc_poly( int vcount, int var_count, const float* vecs,
                    const float* another, Qfloat* results )
    {
        Mat R( 1, vcount, QFLOAT_TYPE, results );
        calc_non_rbf_base( vcount, var_count, vecs, another, results, params.gamma, params.coef0 );
        if( vcount > 0 )
            pow( R, params.degree, R );
    }

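    // Sigmoid kernel. calc_non_rbf_base fills results[j] with
    // t = -2*(gamma*<x, y> + coef0); the two branches below then evaluate
    // tanh(t/2) = (1 - exp(-t))/(1 + exp(-t)) without letting exp() overflow
    // for large |t|, so results[j] = tanh(-(gamma*<x, y> + coef0)).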
    void calc_sigmoid( int vcount, int var_count, const float* vecs,
                       const float* another, Qfloat* results )
    {
        int j;
        calc_non_rbf_base( vcount, var_count, vecs, another, results,
                          -2*params.gamma, -2*params.coef0 );
        // TODO: speedup this
        for( j = 0; j < vcount; j++ )
        {
            Qfloat t = results[j];
            Qfloat e = std::exp(-std::abs(t));
            if( t > 0 )
                results[j] = (Qfloat)((1. - e)/(1. + e));
            else
                results[j] = (Qfloat)((e - 1.)/(e + 1.));
        }
    }

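    // Gaussian (RBF) kernel: K(x, y) = exp(-gamma*||x - y||^2).
    // The squared-distance loop is unrolled by four; exp() is then applied
    // to the whole row of results at once.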
    void calc_rbf( int vcount, int var_count, const float* vecs,
                   const float* another, Qfloat* results )
    {
        double gamma = -params.gamma;
        int j, k;

        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;

            for( k = 0; k <= var_count - 4; k += 4 )
            {
                double t0 = sample[k] - another[k];
                double t1 = sample[k+1] - another[k+1];

                s += t0*t0 + t1*t1;

                t0 = sample[k+2] - another[k+2];
                t1 = sample[k+3] - another[k+3];

                s += t0*t0 + t1*t1;
            }

            for( ; k < var_count; k++ )
            {
                double t0 = sample[k] - another[k];
                s += t0*t0;
            }
            results[j] = (Qfloat)(s*gamma);
        }

        if( vcount > 0 )
        {
            Mat R( 1, vcount, QFLOAT_TYPE, results );
            exp( R, R );
        }
    }

    /// Histogram intersection kernel
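    /// K(x, y) = sum_k min(x_k, y_k)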
    void calc_intersec( int vcount, int var_count, const float* vecs,
                        const float* another, Qfloat* results )
    {
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
                s += std::min(sample[k],another[k]) + std::min(sample[k+1],another[k+1]) +
                std::min(sample[k+2],another[k+2]) + std::min(sample[k+3],another[k+3]);
            for( ; k < var_count; k++ )
                s += std::min(sample[k],another[k]);
            results[j] = (Qfloat)(s);
        }
    }

    /// Exponential chi2 kernel
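    /// K(x, y) = exp(-gamma * sum_k (x_k - y_k)^2/(x_k + y_k))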
    void calc_chi2( int vcount, int var_count, const float* vecs,
                    const float* another, Qfloat* results )
    {
        Mat R( 1, vcount, QFLOAT_TYPE, results );
        double gamma = -params.gamma;
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double chi2 = 0;
            for(k = 0 ; k < var_count; k++ )
            {
                double d = sample[k]-another[k];
                double divisor = sample[k]+another[k];
                /// if divisor == 0, the Chi2 distance would be zero,
                // but the calculation would raise an error because of dividing by zero
                if (divisor != 0)
                {
                    chi2 += d*d/divisor;
                }
            }
            results[j] = (Qfloat) (gamma*chi2);
        }
        if( vcount > 0 )
            exp( R, R );
    }

    void calc( int vcount, int var_count, const float* vecs,
               const float* another, Qfloat* results )
    {
        switch( params.kernelType )
        {
        case SVM::LINEAR:
            calc_linear(vcount, var_count, vecs, another, results);
            break;
        case SVM::RBF:
            calc_rbf(vcount, var_count, vecs, another, results);
            break;
        case SVM::POLY:
            calc_poly(vcount, var_count, vecs, another, results);
            break;
        case SVM::SIGMOID:
            calc_sigmoid(vcount, var_count, vecs, another, results);
            break;
        case SVM::CHI2:
            calc_chi2(vcount, var_count, vecs, another, results);
            break;
        case SVM::INTER:
            calc_intersec(vcount, var_count, vecs, another, results);
            break;
        default:
            CV_Error(CV_StsBadArg, "Unknown kernel type");
        }
        const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
        for( int j = 0; j < vcount; j++ )
        {
            if( results[j] > max_val )
                results[j] = max_val;
        }
    }

    SvmParams params;
};



/////////////////////////////////////////////////////////////////////////

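// Fills sidx_all with the sample indices sorted by class label, and
// class_ranges with the boundaries of each class in that ordering:
// the samples of the i-th class occupy sidx_all[class_ranges[i] .. class_ranges[i+1]-1].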
static void sortSamplesByClasses( const Mat& _samples, const Mat& _responses,
                           vector<int>& sidx_all, vector<int>& class_ranges )
{
    int i, nsamples = _samples.rows;
    CV_Assert( _responses.isContinuous() && _responses.checkVector(1, CV_32S) == nsamples );

    setRangeVector(sidx_all, nsamples);

    const int* rptr = _responses.ptr<int>();
    std::sort(sidx_all.begin(), sidx_all.end(), cmp_lt_idx<int>(rptr));
    class_ranges.clear();
    class_ranges.push_back(0);

    for( i = 0; i < nsamples; i++ )
    {
        if( i == nsamples-1 || rptr[sidx_all[i]] != rptr[sidx_all[i+1]] )
            class_ranges.push_back(i+1);
    }
}

//////////////////////// SVM implementation //////////////////////////////

ParamGrid SVM::getDefaultGrid( int param_id )
{
    ParamGrid grid;
    if( param_id == SVM::C )
    {
        grid.minVal = 0.1;
        grid.maxVal = 500;
        grid.logStep = 5; // total iterations = 5
    }
    else if( param_id == SVM::GAMMA )
    {
        grid.minVal = 1e-5;
        grid.maxVal = 0.6;
        grid.logStep = 15; // total iterations = 4
    }
    else if( param_id == SVM::P )
    {
        grid.minVal = 0.01;
        grid.maxVal = 100;
        grid.logStep = 7; // total iterations = 4
    }
    else if( param_id == SVM::NU )
    {
        grid.minVal = 0.01;
        grid.maxVal = 0.2;
        grid.logStep = 3; // total iterations = 3
    }
    else if( param_id == SVM::COEF )
    {
        grid.minVal = 0.1;
        grid.maxVal = 300;
        grid.logStep = 14; // total iterations = 3
    }
    else if( param_id == SVM::DEGREE )
    {
        grid.minVal = 0.01;
        grid.maxVal = 4;
        grid.logStep = 7; // total iterations = 3
    }
    else
        cvError( CV_StsBadArg, "SVM::getDefaultGrid", "Invalid type of parameter "
                "(use one of SVM::C, SVM::GAMMA et al.)", __FILE__, __LINE__ );
    return grid;
}


class SVMImpl : public SVM
{
public:
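    // One binary decision function of the (one-vs-one) classifier:
    // rho is the threshold of the decision function, and ofs is the offset
    // of its first weight in the shared df_index/df_alpha arrays.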
    struct DecisionFunc
    {
        DecisionFunc(double _rho, int _ofs) : rho(_rho), ofs(_ofs) {}
        DecisionFunc() : rho(0.), ofs(0) {}
        double rho;
        int ofs;
    };

    // Generalized SMO+SVMlight algorithm
    // Solves:
    //
    //  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
    //
    //      y^T \alpha = \delta
    //      y_i = +1 or -1
    //      0 <= alpha_i <= Cp for y_i = 1
    //      0 <= alpha_i <= Cn for y_i = -1
    //
    // Given:
    //
    //  Q, b, y, Cp, Cn, and an initial feasible point \alpha
    //  l is the size of vectors and matrices
    //  eps is the stopping criterion
    //
    // solution will be put in \alpha, objective value will be put in obj
    //
    class Solver
    {
    public:
        enum { MIN_CACHE_SIZE = (40 << 20) /* 40Mb */, MAX_CACHE_SIZE = (500 << 20) /* 500Mb */ };

        typedef bool (Solver::*SelectWorkingSet)( int& i, int& j );
        typedef Qfloat* (Solver::*GetRow)( int i, Qfloat* row, Qfloat* dst, bool existed );
        typedef void (Solver::*CalcRho)( double& rho, double& r );

        struct KernelRow
        {
            KernelRow() { idx = -1; prev = next = 0; }
            KernelRow(int _idx, int _prev, int _next) : idx(_idx), prev(_prev), next(_next) {}
            int idx;
            int prev;
            int next;
        };

        struct SolutionInfo
        {
            SolutionInfo() { obj = rho = upper_bound_p = upper_bound_n = r = 0; }
            double obj;
            double rho;
            double upper_bound_p;
            double upper_bound_n;
            double r;   // for Solver_NU
        };

        void clear()
        {
            alpha_vec = 0;
            select_working_set_func = 0;
            calc_rho_func = 0;
            get_row_func = 0;
            lru_cache.clear();
        }

        Solver( const Mat& _samples, const vector<schar>& _y,
                vector<double>& _alpha, const vector<double>& _b,
                double _Cp, double _Cn,
                const Ptr<SVM::Kernel>& _kernel, GetRow _get_row,
                SelectWorkingSet _select_working_set, CalcRho _calc_rho,
                TermCriteria _termCrit )
        {
            clear();

            samples = _samples;
            sample_count = samples.rows;
            var_count = samples.cols;

            y_vec = _y;
            alpha_vec = &_alpha;
            alpha_count = (int)alpha_vec->size();
            b_vec = _b;
            kernel = _kernel;

            C[0] = _Cn;
            C[1] = _Cp;
            eps = _termCrit.epsilon;
            max_iter = _termCrit.maxCount;

            G_vec.resize(alpha_count);
            alpha_status_vec.resize(alpha_count);
            buf[0].resize(sample_count*2);
            buf[1].resize(sample_count*2);

            select_working_set_func = _select_working_set;
            CV_Assert(select_working_set_func != 0);

            calc_rho_func = _calc_rho;
            CV_Assert(calc_rho_func != 0);

            get_row_func = _get_row;
            CV_Assert(get_row_func != 0);

            // assume that for large training sets ~25% of Q matrix is used
            int64 csize = (int64)sample_count*sample_count/4;
            csize = std::max(csize, (int64)(MIN_CACHE_SIZE/sizeof(Qfloat)) );
            csize = std::min(csize, (int64)(MAX_CACHE_SIZE/sizeof(Qfloat)) );
            max_cache_size = (int)((csize + sample_count-1)/sample_count);
            max_cache_size = std::min(std::max(max_cache_size, 1), sample_count);
            cache_size = 0;

            lru_cache.clear();
            lru_cache.resize(sample_count+1, KernelRow(-1, 0, 0));
            lru_first = lru_last = 0;
            lru_cache_data.create(max_cache_size, sample_count, QFLOAT_TYPE);
        }

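        // Kernel rows are kept in an LRU cache implemented as a doubly-linked
        // list threaded through lru_cache (entry i1+1 describes sample i1;
        // index 0 acts as the null link). get_row_base returns the cached row
        // of kernel values K(x_i1, x_j) for all j, computing it first and
        // evicting the least recently used row when the cache is full. For
        // SVR the logical index i may exceed sample_count and is folded back.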
        Qfloat* get_row_base( int i, bool* _existed )
        {
            int i1 = i < sample_count ? i : i - sample_count;
            KernelRow& kr = lru_cache[i1+1];
            if( _existed )
                *_existed = kr.idx >= 0;
            if( kr.idx < 0 )
            {
                if( cache_size < max_cache_size )
                {
                    kr.idx = cache_size;
                    cache_size++;
                    if (!lru_last)
                        lru_last = i1+1;
                }
                else
                {
                    KernelRow& last = lru_cache[lru_last];
                    kr.idx = last.idx;
                    last.idx = -1;
                    lru_cache[last.prev].next = 0;
                    lru_last = last.prev;
                    last.prev = 0;
                    last.next = 0;
                }
                kernel->calc( sample_count, var_count, samples.ptr<float>(),
                              samples.ptr<float>(i1), lru_cache_data.ptr<Qfloat>(kr.idx) );
            }
            else
            {
                if( kr.next )
                    lru_cache[kr.next].prev = kr.prev;
                else
                    lru_last = kr.prev;
                if( kr.prev )
                    lru_cache[kr.prev].next = kr.next;
                else
                    lru_first = kr.next;
            }
            if (lru_first)
                lru_cache[lru_first].prev = i1+1;
            kr.next = lru_first;
            kr.prev = 0;
            lru_first = i1+1;

            return lru_cache_data.ptr<Qfloat>(kr.idx);
        }

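        // C-SVC needs Q_ij = y_i*y_j*K(x_i, x_j); a freshly computed kernel
        // row is multiplied by the sign pattern once, right as it enters the
        // cache, so rows that already existed are returned as-is.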
        Qfloat* get_row_svc( int i, Qfloat* row, Qfloat*, bool existed )
        {
            if( !existed )
            {
                const schar* _y = &y_vec[0];
                int j, len = sample_count;

                if( _y[i] > 0 )
                {
                    for( j = 0; j < len; j++ )
                        row[j] = _y[j]*row[j];
                }
                else
                {
                    for( j = 0; j < len; j++ )
                        row[j] = -_y[j]*row[j];
                }
            }
            return row;
        }

        Qfloat* get_row_one_class( int, Qfloat* row, Qfloat*, bool )
        {
            return row;
        }

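        // SVR doubles each sample into a "positive" and a "negative" copy;
        // the kernel row is expanded into dst of length 2*sample_count, with
        // +K(x_i, .) in one half and -K(x_i, .) in the other, swapped when i
        // refers to a negative copy.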
        Qfloat* get_row_svr( int i, Qfloat* row, Qfloat* dst, bool )
        {
            int j, len = sample_count;
            Qfloat* dst_pos = dst;
            Qfloat* dst_neg = dst + len;
            if( i >= len )
                std::swap(dst_pos, dst_neg);

            for( j = 0; j < len; j++ )
            {
                Qfloat t = row[j];
                dst_pos[j] = t;
                dst_neg[j] = -t;
            }
            return dst;
        }

        Qfloat* get_row( int i, float* dst )
        {
            bool existed = false;
            float* row = get_row_base( i, &existed );
            return (this->*get_row_func)( i, row, dst, existed );
        }

        #undef is_upper_bound
        #define is_upper_bound(i) (alpha_status[i] > 0)

        #undef is_lower_bound
        #define is_lower_bound(i) (alpha_status[i] < 0)

        #undef is_free
        #define is_free(i) (alpha_status[i] == 0)

        #undef get_C
        #define get_C(i) (C[y[i]>0])

        #undef update_alpha_status
        #define update_alpha_status(i) \
            alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)

        #undef reconstruct_gradient
        #define reconstruct_gradient() /* empty for now */

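        // Core SMO loop: repeatedly pick a violating pair (i, j) via
        // select_working_set_func, solve the two-variable subproblem in
        // closed form under the box constraints [0, C_i] x [0, C_j] and the
        // linear constraint y^T alpha = const, then update the gradient G
        // incrementally from the two cached kernel rows.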
        bool solve_generic( SolutionInfo& si )
        {
            const schar* y = &y_vec[0];
            double* alpha = &alpha_vec->at(0);
            schar* alpha_status = &alpha_status_vec[0];
            double* G = &G_vec[0];
            double* b = &b_vec[0];

            int iter = 0;
            int i, j, k;

            // 1. initialize gradient and alpha status
            for( i = 0; i < alpha_count; i++ )
            {
                update_alpha_status(i);
                G[i] = b[i];
                if( fabs(G[i]) > 1e200 )
                    return false;
            }

            for( i = 0; i < alpha_count; i++ )
            {
                if( !is_lower_bound(i) )
                {
                    const Qfloat *Q_i = get_row( i, &buf[0][0] );
                    double alpha_i = alpha[i];

                    for( j = 0; j < alpha_count; j++ )
                        G[j] += alpha_i*Q_i[j];
                }
            }

            // 2. optimization loop
            for(;;)
            {
                const Qfloat *Q_i, *Q_j;
                double C_i, C_j;
                double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
                double delta_alpha_i, delta_alpha_j;

        #ifdef _DEBUG
                for( i = 0; i < alpha_count; i++ )
                {
                    if( fabs(G[i]) > 1e+300 )
                        return false;

                    if( fabs(alpha[i]) > 1e16 )
                        return false;
                }
        #endif

                if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
                    break;

                Q_i = get_row( i, &buf[0][0] );
                Q_j = get_row( j, &buf[1][0] );

                C_i = get_C(i);
                C_j = get_C(j);

                alpha_i = old_alpha_i = alpha[i];
                alpha_j = old_alpha_j = alpha[j];

                if( y[i] != y[j] )
                {
                    double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
                    double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
                    double diff = alpha_i - alpha_j;
                    alpha_i += delta;
                    alpha_j += delta;

                    if( diff > 0 && alpha_j < 0 )
                    {
                        alpha_j = 0;
                        alpha_i = diff;
                    }
                    else if( diff <= 0 && alpha_i < 0 )
                    {
                        alpha_i = 0;
                        alpha_j = -diff;
                    }

                    if( diff > C_i - C_j && alpha_i > C_i )
                    {
                        alpha_i = C_i;
                        alpha_j = C_i - diff;
                    }
                    else if( diff <= C_i - C_j && alpha_j > C_j )
                    {
                        alpha_j = C_j;
                        alpha_i = C_j + diff;
                    }
                }
                else
                {
                    double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
                    double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
                    double sum = alpha_i + alpha_j;
                    alpha_i -= delta;
                    alpha_j += delta;

                    if( sum > C_i && alpha_i > C_i )
                    {
                        alpha_i = C_i;
                        alpha_j = sum - C_i;
                    }
                    else if( sum <= C_i && alpha_j < 0)
                    {
                        alpha_j = 0;
                        alpha_i = sum;
                    }

                    if( sum > C_j && alpha_j > C_j )
                    {
                        alpha_j = C_j;
                        alpha_i = sum - C_j;
                    }
                    else if( sum <= C_j && alpha_i < 0 )
                    {
                        alpha_i = 0;
                        alpha_j = sum;
                    }
                }

                // update alpha
                alpha[i] = alpha_i;
                alpha[j] = alpha_j;
                update_alpha_status(i);
                update_alpha_status(j);

                // update G
                delta_alpha_i = alpha_i - old_alpha_i;
                delta_alpha_j = alpha_j - old_alpha_j;

                for( k = 0; k < alpha_count; k++ )
                    G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
            }

            // calculate rho
            (this->*calc_rho_func)( si.rho, si.r );

            // calculate objective value
            for( i = 0, si.obj = 0; i < alpha_count; i++ )
                si.obj += alpha[i] * (G[i] + b[i]);

            si.obj *= 0.5;

            si.upper_bound_p = C[1];
            si.upper_bound_n = C[0];

            return true;
        }

        // return true if already optimal, return false otherwise
        bool select_working_set( int& out_i, int& out_j )
        {
            // return i,j which maximize -grad(f)^T d , under constraint
            // if alpha_i == C, d != +1
            // if alpha_i == 0, d != -1
            double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
            int Gmax1_idx = -1;

            double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
            int Gmax2_idx = -1;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double t;

                if( y[i] > 0 )    // y = +1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                }
                else        // y = -1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                }
            }

            out_i = Gmax1_idx;
            out_j = Gmax2_idx;

            return Gmax1 + Gmax2 < eps;
        }

        void calc_rho( double& rho, double& r )
        {
            int nr_free = 0;
            double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double yG = y[i]*G[i];

                if( is_lower_bound(i) )
                {
                    if( y[i] > 0 )
                        ub = MIN(ub,yG);
                    else
                        lb = MAX(lb,yG);
                }
                else if( is_upper_bound(i) )
                {
                    if( y[i] < 0)
                        ub = MIN(ub,yG);
                    else
                        lb = MAX(lb,yG);
                }
                else
                {
                    ++nr_free;
                    sum_free += yG;
                }
            }

            rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
            r = 0;
        }

        bool select_working_set_nu_svm( int& out_i, int& out_j )
        {
            // return i,j which maximize -grad(f)^T d , under constraint
            // if alpha_i == C, d != +1
            // if alpha_i == 0, d != -1
            double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
            int Gmax1_idx = -1;

            double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
            int Gmax2_idx = -1;

            double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
            int Gmax3_idx = -1;

            double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
            int Gmax4_idx = -1;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double t;

                if( y[i] > 0 )    // y == +1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                }
                else        // y == -1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
                    {
                        Gmax3 = t;
                        Gmax3_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
                    {
                        Gmax4 = t;
                        Gmax4_idx = i;
                    }
                }
            }

            if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
                return true;

            if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
            {
                out_i = Gmax1_idx;
                out_j = Gmax2_idx;
            }
            else
            {
                out_i = Gmax3_idx;
                out_j = Gmax4_idx;
            }
            return false;
        }

        void calc_rho_nu_svm( double& rho, double& r )
        {
            int nr_free1 = 0, nr_free2 = 0;
            double ub1 = DBL_MAX, ub2 = DBL_MAX;
            double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
            double sum_free1 = 0, sum_free2 = 0;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double G_i = G[i];
                if( y[i] > 0 )
                {
                    if( is_lower_bound(i) )
                        ub1 = MIN( ub1, G_i );
                    else if( is_upper_bound(i) )
                        lb1 = MAX( lb1, G_i );
                    else
                    {
                        ++nr_free1;
                        sum_free1 += G_i;
                    }
                }
                else
                {
                    if( is_lower_bound(i) )
                        ub2 = MIN( ub2, G_i );
                    else if( is_upper_bound(i) )
                        lb2 = MAX( lb2, G_i );
                    else
                    {
                        ++nr_free2;
                        sum_free2 += G_i;
                    }
                }
            }

            double r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
            double r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;

            rho = (r1 - r2)*0.5;
            r = (r1 + r2)*0.5;
        }

        /*
        ///////////////////////// construct and solve various formulations ///////////////////////
        */
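        // Each solve_* wrapper below maps one SVM formulation onto the
        // generic problem min 0.5*alpha^T Q alpha + b^T alpha solved above,
        // by choosing the initial alpha, the linear term b, the labels y and
        // the appropriate row/working-set/rho callbacks.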
        static bool solve_c_svc( const Mat& _samples, const vector<schar>& _y,
                                 double _Cp, double _Cn, const Ptr<SVM::Kernel>& _kernel,
                                 vector<double>& _alpha, SolutionInfo& _si, TermCriteria termCrit )
        {
            int sample_count = _samples.rows;

            _alpha.assign(sample_count, 0.);
            vector<double> _b(sample_count, -1.);

            Solver solver( _samples, _y, _alpha, _b, _Cp, _Cn, _kernel,
                           &Solver::get_row_svc,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] *= _y[i];

            return true;
        }

        static bool solve_nu_svc( const Mat& _samples, const vector<schar>& _y,
                                  double nu, const Ptr<SVM::Kernel>& _kernel,
                                  vector<double>& _alpha, SolutionInfo& _si,
                                  TermCriteria termCrit )
        {
            int sample_count = _samples.rows;

            _alpha.resize(sample_count);
            vector<double> _b(sample_count, 0.);

            double sum_pos = nu * sample_count * 0.5;
            double sum_neg = nu * sample_count * 0.5;

            for( int i = 0; i < sample_count; i++ )
            {
                double a;
                if( _y[i] > 0 )
                {
                    a = std::min(1.0, sum_pos);
                    sum_pos -= a;
                }
                else
                {
                    a = std::min(1.0, sum_neg);
                    sum_neg -= a;
                }
                _alpha[i] = a;
            }

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_svc,
                           &Solver::select_working_set_nu_svm,
                           &Solver::calc_rho_nu_svm,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            double inv_r = 1./_si.r;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] *= _y[i]*inv_r;

            _si.rho *= inv_r;
            _si.obj *= (inv_r*inv_r);
            _si.upper_bound_p = inv_r;
            _si.upper_bound_n = inv_r;

            return true;
        }

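        // One-class SVM: all labels are +1, the first n = round(nu*l) alphas
        // start at 1 and the remainder nu*l - n is placed on the boundary
        // index, so that the initial point satisfies sum(alpha) == nu*l.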
        static bool solve_one_class( const Mat& _samples, double nu,
                                     const Ptr<SVM::Kernel>& _kernel,
                                     vector<double>& _alpha, SolutionInfo& _si,
                                     TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            vector<schar> _y(sample_count, 1);
            vector<double> _b(sample_count, 0.);

            int i, n = cvRound( nu*sample_count );

            _alpha.resize(sample_count);
            for( i = 0; i < sample_count; i++ )
                _alpha[i] = i < n ? 1 : 0;

            if( n < sample_count )
                _alpha[n] = nu * sample_count - n;
            else
                _alpha[n-1] = nu * sample_count - (n-1);

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_one_class,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            return solver.solve_generic(_si);
        }

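        // eps-SVR: every sample is split into two variables, alpha_i^+ in the
        // first half and alpha_i^- in the second, with linear terms p - y_i
        // and p + y_i respectively; after solving, the final coefficient of
        // sample i is alpha_i^+ - alpha_i^-.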
        static bool solve_eps_svr( const Mat& _samples, const vector<float>& _yf,
                                   double p, double C, const Ptr<SVM::Kernel>& _kernel,
                                   vector<double>& _alpha, SolutionInfo& _si,
                                   TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            int alpha_count = sample_count*2;

            CV_Assert( (int)_yf.size() == sample_count );

            _alpha.assign(alpha_count, 0.);
            vector<schar> _y(alpha_count);
            vector<double> _b(alpha_count);

            for( int i = 0; i < sample_count; i++ )
            {
                _b[i] = p - _yf[i];
                _y[i] = 1;

                _b[i+sample_count] = p + _yf[i];
                _y[i+sample_count] = -1;
            }

            Solver solver( _samples, _y, _alpha, _b, C, C, _kernel,
                           &Solver::get_row_svr,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] -= _alpha[i+sample_count];

            return true;
        }


        static bool solve_nu_svr( const Mat& _samples, const vector<float>& _yf,
                                  double nu, double C, const Ptr<SVM::Kernel>& _kernel,
                                  vector<double>& _alpha, SolutionInfo& _si,
                                  TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            int alpha_count = sample_count*2;
            double sum = C * nu * sample_count * 0.5;

            CV_Assert( (int)_yf.size() == sample_count );

            _alpha.resize(alpha_count);
            vector<schar> _y(alpha_count);
            vector<double> _b(alpha_count);

            for( int i = 0; i < sample_count; i++ )
            {
                _alpha[i] = _alpha[i + sample_count] = std::min(sum, C);
                sum -= _alpha[i];

                _b[i] = -_yf[i];
                _y[i] = 1;

                _b[i + sample_count] = _yf[i];
                _y[i + sample_count] = -1;
            }

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_svr,
                           &Solver::select_working_set_nu_svm,
                           &Solver::calc_rho_nu_svm,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] -= _alpha[i+sample_count];

            return true;
        }

        int sample_count;
        int var_count;
        int cache_size;
        int max_cache_size;
        Mat samples;
        SvmParams params;
        vector<KernelRow> lru_cache;
        int lru_first;
        int lru_last;
        Mat lru_cache_data;

        int alpha_count;

        vector<double> G_vec;
        vector<double>* alpha_vec;
        vector<schar> y_vec;
        // -1 - lower bound, 0 - free, 1 - upper bound
        vector<schar> alpha_status_vec;
        vector<double> b_vec;

        vector<Qfloat> buf[2];
        double eps;
        int max_iter;
        double C[2];  // C[0] == Cn, C[1] == Cp
        Ptr<SVM::Kernel> kernel;

        SelectWorkingSet select_working_set_func;
        CalcRho calc_rho_func;
        GetRow get_row_func;
    };

    //////////////////////////////////////////////////////////////////////////////////////////
    SVMImpl()
    {
        clear();
        checkParams();
    }

    ~SVMImpl()
    {
        clear();
    }

    void clear()
    {
        decision_func.clear();
        df_alpha.clear();
        df_index.clear();
        sv.release();
    }

    Mat getSupportVectors() const
    {
        return sv;
    }

    CV_IMPL_PROPERTY(int, Type, params.svmType)
    CV_IMPL_PROPERTY(double, Gamma, params.gamma)
    CV_IMPL_PROPERTY(double, Coef0, params.coef0)
    CV_IMPL_PROPERTY(double, Degree, params.degree)
    CV_IMPL_PROPERTY(double, C, params.C)
    CV_IMPL_PROPERTY(double, Nu, params.nu)
    CV_IMPL_PROPERTY(double, P, params.p)
    CV_IMPL_PROPERTY_S(cv::Mat, ClassWeights, params.classWeights)
    CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)

    int getKernelType() const
    {
        return params.kernelType;
    }

    void setKernel(int kernelType)
    {
        params.kernelType = kernelType;
        if (kernelType != CUSTOM)
            kernel = makePtr<SVMKernelImpl>(params);
    }

    void setCustomKernel(const Ptr<Kernel> &_kernel)
    {
        params.kernelType = CUSTOM;
        kernel = _kernel;
    }

    void checkParams()
    {
        int kernelType = params.kernelType;
        if (kernelType != CUSTOM)
        {
            if( kernelType != LINEAR && kernelType != POLY &&
                kernelType != SIGMOID && kernelType != RBF &&
                kernelType != INTER && kernelType != CHI2)
                CV_Error( CV_StsBadArg, "Unknown/unsupported kernel type" );

            if( kernelType == LINEAR )
                params.gamma = 1;
            else if( params.gamma <= 0 )
                CV_Error( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );

            if( kernelType != SIGMOID && kernelType != POLY )
                params.coef0 = 0;
            else if( params.coef0 < 0 )
                CV_Error( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );

            if( kernelType != POLY )
                params.degree = 0;
            else if( params.degree <= 0 )
                CV_Error( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );

            kernel = makePtr<SVMKernelImpl>(params);
        }
        else
        {
            if (!kernel)
                CV_Error( CV_StsBadArg, "Custom kernel is not set" );
        }

        int svmType = params.svmType;

        if( svmType != C_SVC && svmType != NU_SVC &&
            svmType != ONE_CLASS && svmType != EPS_SVR &&
            svmType != NU_SVR )
            CV_Error( CV_StsBadArg, "Unknown/unsupported SVM type" );

        if( svmType == ONE_CLASS || svmType == NU_SVC )
            params.C = 0;
        else if( params.C <= 0 )
            CV_Error( CV_StsOutOfRange, "The parameter C must be positive" );

        if( svmType == C_SVC || svmType == EPS_SVR )
            params.nu = 0;
        else if( params.nu <= 0 || params.nu >= 1 )
            CV_Error( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );

        if( svmType != EPS_SVR )
            params.p = 0;
        else if( params.p <= 0 )
            CV_Error( CV_StsOutOfRange, "The parameter p must be positive" );

        if( svmType != C_SVC )
            params.classWeights.release();

        if( !(params.termCrit.type & TermCriteria::EPS) )
            params.termCrit.epsilon = DBL_EPSILON;
        params.termCrit.epsilon = std::max(params.termCrit.epsilon, DBL_EPSILON);
        if( !(params.termCrit.type & TermCriteria::COUNT) )
            params.termCrit.maxCount = INT_MAX;
        params.termCrit.maxCount = std::max(params.termCrit.maxCount, 1);
    }

    void setParams( const SvmParams& _params)
    {
        params = _params;
        checkParams();
    }

    int getSVCount(int i) const
    {
        return (i < (int)(decision_func.size()-1) ? decision_func[i+1].ofs :
                (int)df_index.size()) - decision_func[i].ofs;
    }
    bool do_train( const Mat& _samples, const Mat& _responses )
    {
        int svmType = params.svmType;
        int i, j, k, sample_count = _samples.rows;
        vector<double> _alpha;
        Solver::SolutionInfo sinfo;

        CV_Assert( _samples.type() == CV_32F );
        var_count = _samples.cols;

        if( svmType == ONE_CLASS || svmType == EPS_SVR || svmType == NU_SVR )
        {
            int sv_count = 0;
            decision_func.clear();

            vector<float> _yf;
            if( !_responses.empty() )
                _responses.convertTo(_yf, CV_32F);

            bool ok =
            svmType == ONE_CLASS ? Solver::solve_one_class( _samples, params.nu, kernel, _alpha, sinfo, params.termCrit ) :
            svmType == EPS_SVR ? Solver::solve_eps_svr( _samples, _yf, params.p, params.C, kernel, _alpha, sinfo, params.termCrit ) :
            svmType == NU_SVR ? Solver::solve_nu_svr( _samples, _yf, params.nu, params.C, kernel, _alpha, sinfo, params.termCrit ) : false;

            if( !ok )
                return false;

            for( i = 0; i < sample_count; i++ )
                sv_count += fabs(_alpha[i]) > 0;

            CV_Assert(sv_count != 0);

            sv.create(sv_count, _samples.cols, CV_32F);
            df_alpha.resize(sv_count);
            df_index.resize(sv_count);

            for( i = k = 0; i < sample_count; i++ )
            {
                if( std::abs(_alpha[i]) > 0 )
                {
                    _samples.row(i).copyTo(sv.row(k));
                    df_alpha[k] = _alpha[i];
                    df_index[k] = k;
                    k++;
                }
            }

            decision_func.push_back(DecisionFunc(sinfo.rho, 0));
        }
        else
        {
            int class_count = (int)class_labels.total();
            vector<int> svidx, sidx, sidx_all, sv_tab(sample_count, 0);
            Mat temp_samples, class_weights;
            vector<int> class_ranges;
            vector<schar> temp_y;
            double nu = params.nu;
            CV_Assert( svmType == C_SVC || svmType == NU_SVC );

            if( svmType == C_SVC && !params.classWeights.empty() )
            {
                const Mat cw = params.classWeights;

                if( (cw.cols != 1 && cw.rows != 1) ||
                    (int)cw.total() != class_count ||
                    (cw.type() != CV_32F && cw.type() != CV_64F) )
                    CV_Error( CV_StsBadArg, "params.class_weights must be a 1d floating-point vector "
                        "containing as many elements as the number of classes" );

                cw.convertTo(class_weights, CV_64F, params.C);
                //normalize(cw, class_weights, params.C, 0, NORM_L1, CV_64F);
            }

            decision_func.clear();
            df_alpha.clear();
            df_index.clear();

            sortSamplesByClasses( _samples, _responses, sidx_all, class_ranges );

            // check that during cross-validation the samples from all the classes are present
            if( class_ranges[class_count] <= 0 )
                CV_Error( CV_StsBadArg, "During cross-validation one or more of the classes has "
                "fallen out of the sample. Try to enlarge <Params::k_fold>" );
1440 
1441             if( svmType == NU_SVC )
1442             {
1443                 // check if nu is feasible
1444                 for( i = 0; i < class_count; i++ )
1445                 {
1446                     int ci = class_ranges[i+1] - class_ranges[i];
1447                     for( j = i+1; j< class_count; j++ )
1448                     {
1449                         int cj = class_ranges[j+1] - class_ranges[j];
1450                         if( nu*(ci + cj)*0.5 > std::min( ci, cj ) )
1451                             // TODO: add some diagnostic
1452                             return false;
1453                     }
1454                 }
1455             }
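            // The feasibility check above enforces the libsvm condition
            //     nu*(ci + cj)/2 <= min(ci, cj)  for every pair of classes.
            // For example, with ci = 90 and cj = 10 samples, nu must not exceed
            // 2*min(90, 10)/(90 + 10) = 0.2; any larger nu makes the NU_SVC
            // problem infeasible.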

            size_t samplesize = _samples.cols*_samples.elemSize();

            // train n*(n-1)/2 classifiers
            for( i = 0; i < class_count; i++ )
            {
                for( j = i+1; j < class_count; j++ )
                {
                    int si = class_ranges[i], ci = class_ranges[i+1] - si;
                    int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
                    double Cp = params.C, Cn = Cp;

                    temp_samples.create(ci + cj, _samples.cols, _samples.type());
                    sidx.resize(ci + cj);
                    temp_y.resize(ci + cj);

                    // form input for the binary classification problem
                    for( k = 0; k < ci+cj; k++ )
                    {
                        int idx = k < ci ? si+k : sj+k-ci;
                        memcpy(temp_samples.ptr(k), _samples.ptr(sidx_all[idx]), samplesize);
                        sidx[k] = sidx_all[idx];
                        temp_y[k] = k < ci ? 1 : -1;
                    }

                    if( !class_weights.empty() )
                    {
                        Cp = class_weights.at<double>(i);
                        Cn = class_weights.at<double>(j);
                    }

                    DecisionFunc df;
                    bool ok = params.svmType == C_SVC ?
                                Solver::solve_c_svc( temp_samples, temp_y, Cp, Cn,
                                                     kernel, _alpha, sinfo, params.termCrit ) :
                              params.svmType == NU_SVC ?
                                Solver::solve_nu_svc( temp_samples, temp_y, params.nu,
                                                      kernel, _alpha, sinfo, params.termCrit ) :
                              false;
                    if( !ok )
                        return false;
                    df.rho = sinfo.rho;
                    df.ofs = (int)df_index.size();
                    decision_func.push_back(df);

                    for( k = 0; k < ci + cj; k++ )
                    {
                        if( std::abs(_alpha[k]) > 0 )
                        {
                            int idx = k < ci ? si+k : sj+k-ci;
                            sv_tab[sidx_all[idx]] = 1;
                            df_index.push_back(sidx_all[idx]);
                            df_alpha.push_back(_alpha[k]);
                        }
                    }
                }
            }
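            // With class_count = N this one-vs-one scheme produces N*(N-1)/2
            // decision functions; e.g. 4 classes yield 6 of them, stored in the
            // order (0,1), (0,2), (0,3), (1,2), (1,3), (2,3), which is the same
            // order PredictBody walks them when voting.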

            // enumerate the support vectors in sv_tab (1-based ranks; 0 means
            // the sample is not a support vector) and gather them into sv
            for( i = 0, k = 0; i < sample_count; i++ )
            {
                if( sv_tab[i] )
                    sv_tab[i] = ++k;
            }

            int sv_total = k;
            sv.create(sv_total, _samples.cols, _samples.type());

            for( i = 0; i < sample_count; i++ )
            {
                if( !sv_tab[i] )
                    continue;
                memcpy(sv.ptr(sv_tab[i]-1), _samples.ptr(i), samplesize);
            }

            // remap df_index from sample indices to rows of sv
            int n = (int)df_index.size();
            for( i = 0; i < n; i++ )
            {
                CV_Assert( sv_tab[df_index[i]] > 0 );
                df_index[i] = sv_tab[df_index[i]] - 1;
            }
        }

        optimize_linear_svm();
        return true;
    }

    void optimize_linear_svm()
    {
        // we optimize only linear SVM: compress all the support vectors into one.
        if( params.kernelType != LINEAR )
            return;

        int i, df_count = (int)decision_func.size();

        for( i = 0; i < df_count; i++ )
        {
            if( getSVCount(i) != 1 )
                break;
        }

        // if every decision function uses a single support vector,
        // it's already compressed; skip it then.
        if( i == df_count )
            return;

        AutoBuffer<double> vbuf(var_count);
        double* v = vbuf;
        Mat new_sv(df_count, var_count, CV_32F);

        vector<DecisionFunc> new_df;

        for( i = 0; i < df_count; i++ )
        {
            float* dst = new_sv.ptr<float>(i);
            memset(v, 0, var_count*sizeof(v[0]));
            int j, k, sv_count = getSVCount(i);
            const DecisionFunc& df = decision_func[i];
            const int* sv_index = &df_index[df.ofs];
            const double* sv_alpha = &df_alpha[df.ofs];
            for( j = 0; j < sv_count; j++ )
            {
                const float* src = sv.ptr<float>(sv_index[j]);
                double a = sv_alpha[j];
                for( k = 0; k < var_count; k++ )
                    v[k] += src[k]*a;
            }
            for( k = 0; k < var_count; k++ )
                dst[k] = (float)v[k];
            new_df.push_back(DecisionFunc(df.rho, i));
        }

        setRangeVector(df_index, df_count);
        df_alpha.assign(df_count, 1.);
        std::swap(sv, new_sv);
        std::swap(decision_func, new_df);
    }
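    // For the linear kernel K(x, y) = x.y each decision function
    //     f(x) = sum_j alpha_j*(sv_j . x) - rho = (sum_j alpha_j*sv_j) . x - rho
    // collapses into a single weight vector w = sum_j alpha_j*sv_j, which is what
    // the loop above accumulates row by row; prediction then costs one dot
    // product per decision function instead of one per support vector.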

    bool train( const Ptr<TrainData>& data, int )
    {
        clear();

        checkParams();

        int svmType = params.svmType;
        Mat samples = data->getTrainSamples();
        Mat responses;

        if( svmType == C_SVC || svmType == NU_SVC )
        {
            responses = data->getTrainNormCatResponses();
            if( responses.empty() )
                CV_Error(CV_StsBadArg, "in the case of a classification problem the responses must be categorical; "
                                       "either specify varType when creating TrainData, or pass integer responses");
            class_labels = data->getClassLabels();
        }
        else
            responses = data->getTrainResponses();

        if( !do_train( samples, responses ))
        {
            clear();
            return false;
        }

        return true;
    }
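    // A minimal usage sketch (illustrative, not part of the implementation;
    // assumes the caller has prepared a CV_32F matrix `samples` with one sample
    // per row and an integer label vector `labels`):
    //
    //     Ptr<SVM> svm = SVM::create();
    //     svm->setType(SVM::C_SVC);
    //     svm->setKernel(SVM::RBF);
    //     Ptr<TrainData> td = TrainData::create(samples, ROW_SAMPLE, labels);
    //     svm->train(td);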

    bool trainAuto( const Ptr<TrainData>& data, int k_fold,
                    ParamGrid C_grid, ParamGrid gamma_grid, ParamGrid p_grid,
                    ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
                    bool balanced )
    {
        checkParams();

        int svmType = params.svmType;
        RNG rng((uint64)-1);

        if( svmType == ONE_CLASS )
            // current implementation of "auto" svm does not support the 1-class case.
            return train( data, 0 );

        clear();

        CV_Assert( k_fold >= 2 );

        // All the parameters except, possibly, <coef0> are positive;
        // <coef0> is nonnegative.
        #define CHECK_GRID(grid, param) \
        if( grid.logStep <= 1 ) \
        { \
            grid.minVal = grid.maxVal = params.param; \
            grid.logStep = 10; \
        } \
        else \
            checkParamGrid(grid)

        CHECK_GRID(C_grid, C);
        CHECK_GRID(gamma_grid, gamma);
        CHECK_GRID(p_grid, p);
        CHECK_GRID(nu_grid, nu);
        CHECK_GRID(coef_grid, coef0);
        CHECK_GRID(degree_grid, degree);

        // collapse the grids of parameters that are not used by the chosen
        // kernel/SVM type, so that each contributes exactly one pass:
        if( params.kernelType != POLY )
            degree_grid.minVal = degree_grid.maxVal = params.degree;
        if( params.kernelType == LINEAR )
            gamma_grid.minVal = gamma_grid.maxVal = params.gamma;
        if( params.kernelType != POLY && params.kernelType != SIGMOID )
            coef_grid.minVal = coef_grid.maxVal = params.coef0;
        if( svmType == NU_SVC || svmType == ONE_CLASS )
            C_grid.minVal = C_grid.maxVal = params.C;
        if( svmType == C_SVC || svmType == EPS_SVR )
            nu_grid.minVal = nu_grid.maxVal = params.nu;
        if( svmType != EPS_SVR )
            p_grid.minVal = p_grid.maxVal = params.p;

        Mat samples = data->getTrainSamples();
        Mat responses;
        bool is_classification = false;
        int class_count = (int)class_labels.total();

        if( svmType == C_SVC || svmType == NU_SVC )
        {
            responses = data->getTrainNormCatResponses();
            class_labels = data->getClassLabels();
            class_count = (int)class_labels.total();
            is_classification = true;

            vector<int> temp_class_labels;
            setRangeVector(temp_class_labels, class_count);

            // temporarily replace class labels with 0, 1, ..., NCLASSES-1
            Mat(temp_class_labels).copyTo(class_labels);
        }
        else
            responses = data->getTrainResponses();

        CV_Assert(samples.type() == CV_32F);

        int sample_count = samples.rows;
        var_count = samples.cols;
        size_t sample_size = var_count*samples.elemSize();

        vector<int> sidx;
        setRangeVector(sidx, sample_count);

        int i, j, k;

        // randomly permute training samples
        for( i = 0; i < sample_count; i++ )
        {
            int i1 = rng.uniform(0, sample_count);
            int i2 = rng.uniform(0, sample_count);
            std::swap(sidx[i1], sidx[i2]);
        }

        if( is_classification && class_count == 2 && balanced )
        {
            // reshuffle the training set in such a way that
            // instances of each class are divided more or less evenly
            // between the k_fold parts.
            vector<int> sidx0, sidx1;

            for( i = 0; i < sample_count; i++ )
            {
                if( responses.at<int>(sidx[i]) == 0 )
                    sidx0.push_back(sidx[i]);
                else
                    sidx1.push_back(sidx[i]);
            }

            int n0 = (int)sidx0.size(), n1 = (int)sidx1.size();
            int a0 = 0, a1 = 0;
            sidx.clear();
            for( k = 0; k < k_fold; k++ )
            {
                int b0 = ((k+1)*n0 + k_fold/2)/k_fold, b1 = ((k+1)*n1 + k_fold/2)/k_fold;
                int a = (int)sidx.size(), b = a + (b0 - a0) + (b1 - a1);
                for( i = a0; i < b0; i++ )
                    sidx.push_back(sidx0[i]);
                for( i = a1; i < b1; i++ )
                    sidx.push_back(sidx1[i]);
                for( i = 0; i < (b - a); i++ )
                {
                    int i1 = rng.uniform(a, b);
                    int i2 = rng.uniform(a, b);
                    std::swap(sidx[i1], sidx[i2]);
                }
                a0 = b0; a1 = b1;
            }
        }

        int test_sample_count = (sample_count + k_fold/2)/k_fold;
        int train_sample_count = sample_count - test_sample_count;

        SvmParams best_params = params;
        double min_error = FLT_MAX;

        int rtype = responses.type();

        Mat temp_train_samples(train_sample_count, var_count, CV_32F);
        Mat temp_test_samples(test_sample_count, var_count, CV_32F);
        Mat temp_train_responses(train_sample_count, 1, rtype);
        Mat temp_test_responses;

        // If grid.minVal == grid.maxVal, this will allow one and only one pass through the loop with params.var = grid.minVal.
        #define FOR_IN_GRID(var, grid) \
            for( params.var = grid.minVal; params.var == grid.minVal || params.var < grid.maxVal; params.var = (grid.minVal == grid.maxVal) ? grid.maxVal + 1 : params.var * grid.logStep )
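        // For example, a grid with minVal = 0.1, maxVal = 100 and logStep = 10
        // makes the loop try params.var = 0.1, 1 and 10 (maxVal itself is
        // excluded unless minVal == maxVal, in which case exactly one pass with
        // params.var = minVal is made).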

        FOR_IN_GRID(C, C_grid)
        FOR_IN_GRID(gamma, gamma_grid)
        FOR_IN_GRID(p, p_grid)
        FOR_IN_GRID(nu, nu_grid)
        FOR_IN_GRID(coef0, coef_grid)
        FOR_IN_GRID(degree, degree_grid)
        {
            // make sure we updated the kernel and other parameters
            setParams(params);

            double error = 0;
            for( k = 0; k < k_fold; k++ )
            {
                int start = (k*sample_count + k_fold/2)/k_fold;
                for( i = 0; i < train_sample_count; i++ )
                {
                    j = sidx[(i+start)%sample_count];
                    memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
                    if( is_classification )
                        temp_train_responses.at<int>(i) = responses.at<int>(j);
                    else if( !responses.empty() )
                        temp_train_responses.at<float>(i) = responses.at<float>(j);
                }

                // Train SVM on <train_sample_count> samples
                if( !do_train( temp_train_samples, temp_train_responses ))
                    continue;

                // Copy the remaining <test_sample_count> samples into the test fold
                for( i = 0; i < test_sample_count; i++ )
                {
                    j = sidx[(i+start+train_sample_count) % sample_count];
                    memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
                }

                predict(temp_test_samples, temp_test_responses, 0);
                for( i = 0; i < test_sample_count; i++ )
                {
                    float val = temp_test_responses.at<float>(i);
                    j = sidx[(i+start+train_sample_count) % sample_count];
                    if( is_classification )
                        error += (float)(val != responses.at<int>(j));
                    else
                    {
                        val -= responses.at<float>(j);
                        error += val*val;
                    }
                }
            }
            if( min_error > error )
            {
                min_error   = error;
                best_params = params;
            }
        }

        params = best_params;
        return do_train( samples, responses );
    }
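    // A minimal usage sketch (illustrative): searching all parameters over the
    // built-in default grids with 10-fold cross-validation:
    //
    //     svm->trainAuto(td, 10,
    //                    SVM::getDefaultGrid(SVM::C),
    //                    SVM::getDefaultGrid(SVM::GAMMA),
    //                    SVM::getDefaultGrid(SVM::P),
    //                    SVM::getDefaultGrid(SVM::NU),
    //                    SVM::getDefaultGrid(SVM::COEF),
    //                    SVM::getDefaultGrid(SVM::DEGREE));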

    struct PredictBody : ParallelLoopBody
    {
        PredictBody( const SVMImpl* _svm, const Mat& _samples, Mat& _results, bool _returnDFVal )
        {
            svm = _svm;
            results = &_results;
            samples = &_samples;
            returnDFVal = _returnDFVal;
        }

        void operator()( const Range& range ) const
        {
            int svmType = svm->params.svmType;
            int sv_total = svm->sv.rows;
            int class_count = !svm->class_labels.empty() ? (int)svm->class_labels.total() : svmType == ONE_CLASS ? 1 : 0;

            AutoBuffer<float> _buffer(sv_total + (class_count+1)*2);
            float* buffer = _buffer;

            int i, j, dfi, k, si;

            if( svmType == EPS_SVR || svmType == NU_SVR || svmType == ONE_CLASS )
            {
                for( si = range.start; si < range.end; si++ )
                {
                    const float* row_sample = samples->ptr<float>(si);
                    svm->kernel->calc( sv_total, svm->var_count, svm->sv.ptr<float>(), row_sample, buffer );

                    const SVMImpl::DecisionFunc* df = &svm->decision_func[0];
                    double sum = -df->rho;
                    for( i = 0; i < sv_total; i++ )
                        sum += buffer[i]*svm->df_alpha[i];
                    float result = svm->params.svmType == ONE_CLASS && !returnDFVal ? (float)(sum > 0) : (float)sum;
                    results->at<float>(si) = result;
                }
            }
            else if( svmType == C_SVC || svmType == NU_SVC )
            {
                int* vote = (int*)(buffer + sv_total);

                for( si = range.start; si < range.end; si++ )
                {
                    svm->kernel->calc( sv_total, svm->var_count, svm->sv.ptr<float>(),
                                       samples->ptr<float>(si), buffer );
                    double sum = 0.;

                    memset( vote, 0, class_count*sizeof(vote[0]));

                    for( i = dfi = 0; i < class_count; i++ )
                    {
                        for( j = i+1; j < class_count; j++, dfi++ )
                        {
                            const DecisionFunc& df = svm->decision_func[dfi];
                            sum = -df.rho;
                            int sv_count = svm->getSVCount(dfi);
                            const double* alpha = &svm->df_alpha[df.ofs];
                            const int* sv_index = &svm->df_index[df.ofs];
                            for( k = 0; k < sv_count; k++ )
                                sum += alpha[k]*buffer[sv_index[k]];

                            vote[sum > 0 ? i : j]++;
                        }
                    }

                    for( i = 1, k = 0; i < class_count; i++ )
                    {
                        if( vote[i] > vote[k] )
                            k = i;
                    }
                    float result = returnDFVal && class_count == 2 ?
                        (float)sum : (float)(svm->class_labels.at<int>(k));
                    results->at<float>(si) = result;
                }
            }
            else
                CV_Error( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
                         "the SVM structure is probably corrupted" );
        }

        const SVMImpl* svm;
        const Mat* samples;
        Mat* results;
        bool returnDFVal;
    };

    float predict( InputArray _samples, OutputArray _results, int flags ) const
    {
        float result = 0;
        Mat samples = _samples.getMat(), results;
        int nsamples = samples.rows;
        bool returnDFVal = (flags & RAW_OUTPUT) != 0;

        CV_Assert( samples.cols == var_count && samples.type() == CV_32F );

        if( _results.needed() )
        {
            _results.create( nsamples, 1, samples.type() );
            results = _results.getMat();
        }
        else
        {
            CV_Assert( nsamples == 1 );
            results = Mat(1, 1, CV_32F, &result);
        }

        PredictBody invoker(this, samples, results, returnDFVal);
        if( nsamples < 10 )
            invoker(Range(0, nsamples));
        else
            parallel_for_(Range(0, nsamples), invoker);
        return result;
    }
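    // Usage note (illustrative): for a single CV_32F row sample the response is
    // returned directly; RAW_OUTPUT requests the decision-function value instead
    // of the label:
    //
    //     float label = svm->predict(sample);
    //     float dfval = svm->predict(sample, noArray(), StatModel::RAW_OUTPUT);
    //
    // When an output array is supplied, the per-sample values land there and the
    // returned float stays 0.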

    double getDecisionFunction(int i, OutputArray _alpha, OutputArray _svidx ) const
    {
        CV_Assert( 0 <= i && i < (int)decision_func.size());
        const DecisionFunc& df = decision_func[i];
        int count = getSVCount(i);
        Mat(1, count, CV_64F, (double*)&df_alpha[df.ofs]).copyTo(_alpha);
        Mat(1, count, CV_32S, (int*)&df_index[df.ofs]).copyTo(_svidx);
        return df.rho;
    }
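    // Usage sketch (illustrative): for a trained two-class linear SVM the single
    // decision function can be turned into a weight vector, e.g. to feed a
    // HOG-style detector:
    //
    //     Mat alpha, svidx;
    //     double rho = svm->getDecisionFunction(0, alpha, svidx);
    //     Mat svs = svm->getSupportVectors();
    //     // after optimize_linear_svm() svs has a single row and alpha == {1.},
    //     // so w = svs.row(0) and the bias is -rho.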

    void write_params( FileStorage& fs ) const
    {
        int svmType = params.svmType;
        int kernelType = params.kernelType;

        String svm_type_str =
            svmType == C_SVC ? "C_SVC" :
            svmType == NU_SVC ? "NU_SVC" :
            svmType == ONE_CLASS ? "ONE_CLASS" :
            svmType == EPS_SVR ? "EPS_SVR" :
            svmType == NU_SVR ? "NU_SVR" : format("Unknown_%d", svmType);
        String kernel_type_str =
            kernelType == LINEAR ? "LINEAR" :
            kernelType == POLY ? "POLY" :
            kernelType == RBF ? "RBF" :
            kernelType == SIGMOID ? "SIGMOID" :
            kernelType == CHI2 ? "CHI2" :
            kernelType == INTER ? "INTER" : format("Unknown_%d", kernelType);

        fs << "svmType" << svm_type_str;

        // save kernel
        fs << "kernel" << "{" << "type" << kernel_type_str;

        if( kernelType == POLY )
            fs << "degree" << params.degree;

        if( kernelType != LINEAR )
            fs << "gamma" << params.gamma;

        if( kernelType == POLY || kernelType == SIGMOID )
            fs << "coef0" << params.coef0;

        fs << "}";

        if( svmType == C_SVC || svmType == EPS_SVR || svmType == NU_SVR )
            fs << "C" << params.C;

        if( svmType == NU_SVC || svmType == ONE_CLASS || svmType == NU_SVR )
            fs << "nu" << params.nu;

        if( svmType == EPS_SVR )
            fs << "p" << params.p;

        fs << "term_criteria" << "{:";
        if( params.termCrit.type & TermCriteria::EPS )
            fs << "epsilon" << params.termCrit.epsilon;
        if( params.termCrit.type & TermCriteria::COUNT )
            fs << "iterations" << params.termCrit.maxCount;
        fs << "}";
    }

    bool isTrained() const
    {
        return !sv.empty();
    }

    bool isClassifier() const
    {
        return params.svmType == C_SVC || params.svmType == NU_SVC || params.svmType == ONE_CLASS;
    }

    int getVarCount() const
    {
        return var_count;
    }

    String getDefaultName() const
    {
        return "opencv_ml_svm";
    }

    void write( FileStorage& fs ) const
    {
        int class_count = !class_labels.empty() ? (int)class_labels.total() :
                          params.svmType == ONE_CLASS ? 1 : 0;
        if( !isTrained() )
            CV_Error( CV_StsParseError, "SVM model data is invalid; the model must be trained before it can be written" );

        write_params( fs );

        fs << "var_count" << var_count;

        if( class_count > 0 )
        {
            fs << "class_count" << class_count;

            if( !class_labels.empty() )
                fs << "class_labels" << class_labels;

            if( !params.classWeights.empty() )
                fs << "class_weights" << params.classWeights;
        }

        // write the joint collection of support vectors
        int i, sv_total = sv.rows;
        fs << "sv_total" << sv_total;
        fs << "support_vectors" << "[";
        for( i = 0; i < sv_total; i++ )
        {
            fs << "[:";
            fs.writeRaw("f", sv.ptr(i), sv.cols*sv.elemSize());
            fs << "]";
        }
        fs << "]";

        // write decision functions
        int df_count = (int)decision_func.size();

        fs << "decision_functions" << "[";
        for( i = 0; i < df_count; i++ )
        {
            const DecisionFunc& df = decision_func[i];
            int sv_count = getSVCount(i);
            fs << "{" << "sv_count" << sv_count
               << "rho" << df.rho
               << "alpha" << "[:";
            fs.writeRaw("d", (const uchar*)&df_alpha[df.ofs], sv_count*sizeof(df_alpha[0]));
            fs << "]";
            if( class_count > 2 )
            {
                fs << "index" << "[:";
                fs.writeRaw("i", (const uchar*)&df_index[df.ofs], sv_count*sizeof(df_index[0]));
                fs << "]";
            }
            else
                CV_Assert( sv_count == sv_total );
            fs << "}";
        }
        fs << "]";
    }
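    // For reference, the section emitted above looks roughly like this in YAML
    // (abridged; the exact set of fields depends on the SVM and kernel types):
    //
    //     svmType: C_SVC
    //     kernel: { type: RBF, gamma: 1. }
    //     C: 1.
    //     term_criteria: { epsilon: ..., iterations: ... }
    //     var_count: ...
    //     class_count: 2
    //     class_labels: ...
    //     sv_total: ...
    //     support_vectors: [ [: ... ], ... ]
    //     decision_functions: [ { sv_count: ..., rho: ..., alpha: [: ... ] } ]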

    void read_params( const FileNode& fn )
    {
        SvmParams _params;

        // check for old naming
        String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]);
        int svmType =
            svm_type_str == "C_SVC" ? C_SVC :
            svm_type_str == "NU_SVC" ? NU_SVC :
            svm_type_str == "ONE_CLASS" ? ONE_CLASS :
            svm_type_str == "EPS_SVR" ? EPS_SVR :
            svm_type_str == "NU_SVR" ? NU_SVR : -1;

        if( svmType < 0 )
            CV_Error( CV_StsParseError, "Missing or invalid SVM type" );

        FileNode kernel_node = fn["kernel"];
        if( kernel_node.empty() )
            CV_Error( CV_StsParseError, "SVM kernel tag is not found" );

        String kernel_type_str = (String)kernel_node["type"];
        int kernelType =
            kernel_type_str == "LINEAR" ? LINEAR :
            kernel_type_str == "POLY" ? POLY :
            kernel_type_str == "RBF" ? RBF :
            kernel_type_str == "SIGMOID" ? SIGMOID :
            kernel_type_str == "CHI2" ? CHI2 :
            kernel_type_str == "INTER" ? INTER : CUSTOM;

        if( kernelType == CUSTOM )
            CV_Error( CV_StsParseError, "Invalid SVM kernel type (or custom kernel)" );

        _params.svmType = svmType;
        _params.kernelType = kernelType;
        _params.degree = (double)kernel_node["degree"];
        _params.gamma = (double)kernel_node["gamma"];
        _params.coef0 = (double)kernel_node["coef0"];

        _params.C = (double)fn["C"];
        _params.nu = (double)fn["nu"];
        _params.p = (double)fn["p"];
        _params.classWeights = Mat();

        FileNode tcnode = fn["term_criteria"];
        if( !tcnode.empty() )
        {
            _params.termCrit.epsilon = (double)tcnode["epsilon"];
            _params.termCrit.maxCount = (int)tcnode["iterations"];
            _params.termCrit.type = (_params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
                                   (_params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
        }
        else
            _params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );

        setParams( _params );
    }

    void read( const FileNode& fn )
    {
        clear();

        // read SVM parameters
        read_params( fn );

        // and top-level data
        int i, sv_total = (int)fn["sv_total"];
        var_count = (int)fn["var_count"];
        int class_count = (int)fn["class_count"];

        if( sv_total <= 0 || var_count <= 0 )
            CV_Error( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );

        FileNode m = fn["class_labels"];
        if( !m.empty() )
            m >> class_labels;
        m = fn["class_weights"];
        if( !m.empty() )
            m >> params.classWeights;

        if( class_count > 1 && (class_labels.empty() || (int)class_labels.total() != class_count))
            CV_Error( CV_StsParseError, "Array of class labels is missing or invalid" );

        // read support vectors
        FileNode sv_node = fn["support_vectors"];

        CV_Assert((int)sv_node.size() == sv_total);
        sv.create(sv_total, var_count, CV_32F);

        FileNodeIterator sv_it = sv_node.begin();
        for( i = 0; i < sv_total; i++, ++sv_it )
        {
            (*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize());
        }

        // read decision functions
        int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
        FileNode df_node = fn["decision_functions"];

        CV_Assert((int)df_node.size() == df_count);

        FileNodeIterator df_it = df_node.begin();
        for( i = 0; i < df_count; i++, ++df_it )
        {
            FileNode dfi = *df_it;
            DecisionFunc df;
            int sv_count = (int)dfi["sv_count"];
            int ofs = (int)df_index.size();
            df.rho = (double)dfi["rho"];
            df.ofs = ofs;
            df_index.resize(ofs + sv_count);
            df_alpha.resize(ofs + sv_count);
            dfi["alpha"].readRaw("d", (uchar*)&df_alpha[ofs], sv_count*sizeof(df_alpha[0]));
            if( class_count > 2 )
                dfi["index"].readRaw("i", (uchar*)&df_index[ofs], sv_count*sizeof(df_index[0]));
            decision_func.push_back(df);
        }
        if( class_count <= 2 )
            setRangeVector(df_index, sv_total);
        if( (int)fn["optimize_linear"] != 0 )
            optimize_linear_svm();
    }

    SvmParams params;
    Mat class_labels;
    int var_count;
    Mat sv;
    vector<DecisionFunc> decision_func;
    vector<double> df_alpha;
    vector<int> df_index;

    Ptr<Kernel> kernel;
};


Ptr<SVM> SVM::create()
{
    return makePtr<SVMImpl>();
}

} // namespace ml
} // namespace cv

/* End of file. */