/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

namespace cv
{
namespace ml
{

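// Eigenvalues of the cluster covariances are clamped from below by this value
// so that the inverses stored in invCovsEigenValues stay finite and the
// covariance matrices remain numerically invertible.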
const double minEigenValue = DBL_EPSILON;

class CV_EXPORTS EMImpl : public EM
{
public:

    int nclusters;
    int covMatType;
    TermCriteria termCrit;

    CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, termCrit)

    void setClustersNumber(int val)
    {
        nclusters = val;
        CV_Assert(nclusters > 1);
    }

    int getClustersNumber() const
    {
        return nclusters;
    }

    void setCovarianceMatrixType(int val)
    {
        covMatType = val;
        CV_Assert(covMatType == COV_MAT_SPHERICAL ||
                  covMatType == COV_MAT_DIAGONAL ||
                  covMatType == COV_MAT_GENERIC);
    }

    int getCovarianceMatrixType() const
    {
        return covMatType;
    }

    EMImpl()
    {
        nclusters = DEFAULT_NCLUSTERS;
        covMatType = EM::COV_MAT_DIAGONAL;
        termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
    }

    virtual ~EMImpl() {}

    void clear()
    {
        trainSamples.release();
        trainProbs.release();
        trainLogLikelihoods.release();
        trainLabels.release();

        weights.release();
        means.release();
        covs.clear();

        covsEigenValues.clear();
        invCovsEigenValues.clear();
        covsRotateMats.clear();

        logWeightDivDet.release();
    }

    bool train(const Ptr<TrainData>& data, int)
    {
        Mat samples = data->getTrainSamples(), labels;
        return trainEM(samples, labels, noArray(), noArray());
    }

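    // Training entry points. trainEM() starts from a k-means initialization
    // (START_AUTO_STEP), trainE() starts with a first E-step from user-supplied
    // means (and optionally covariances and weights), and trainM() starts with a
    // first M-step from user-supplied per-sample probabilities.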
    bool trainEM(InputArray samples,
               OutputArray logLikelihoods,
               OutputArray labels,
               OutputArray probs)
    {
        Mat samplesMat = samples.getMat();
        setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
        return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
    }

    bool trainE(InputArray samples,
                InputArray _means0,
                InputArray _covs0,
                InputArray _weights0,
                OutputArray logLikelihoods,
                OutputArray labels,
                OutputArray probs)
    {
        Mat samplesMat = samples.getMat();
        std::vector<Mat> covs0;
        _covs0.getMatVector(covs0);

        Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();

        setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
                     !_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
        return doTrain(START_E_STEP, logLikelihoods, labels, probs);
    }

    bool trainM(InputArray samples,
                InputArray _probs0,
                OutputArray logLikelihoods,
                OutputArray labels,
                OutputArray probs)
    {
        Mat samplesMat = samples.getMat();
        Mat probs0 = _probs0.getMat();

        setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
        return doTrain(START_M_STEP, logLikelihoods, labels, probs);
    }

    float predict(InputArray _inputs, OutputArray _outputs, int) const
    {
        bool needprobs = _outputs.needed();
        Mat samples = _inputs.getMat(), probs, probsrow;
        int ptype = CV_32F;
        float firstres = 0.f;
        int i, nsamples = samples.rows;

        if( needprobs )
        {
            if( _outputs.fixedType() )
                ptype = _outputs.type();
            _outputs.create(samples.rows, nclusters, ptype);
            probs = _outputs.getMat();
        }
        else
            nsamples = std::min(nsamples, 1);

        for( i = 0; i < nsamples; i++ )
        {
            if( needprobs )
                probsrow = probs.row(i);
            Vec2d res = computeProbabilities(samples.row(i), needprobs ? &probsrow : 0, ptype);
            if( i == 0 )
                firstres = (float)res[1];
        }
        return firstres;
    }

    Vec2d predict2(InputArray _sample, OutputArray _probs) const
    {
        int ptype = CV_32F;
        Mat sample = _sample.getMat();
        CV_Assert(isTrained());

        CV_Assert(!sample.empty());
        if(sample.type() != CV_64FC1)
        {
            Mat tmp;
            sample.convertTo(tmp, CV_64FC1);
            sample = tmp;
        }
        sample = sample.reshape(1, 1);

        Mat probs;
        if( _probs.needed() )
        {
            if( _probs.fixedType() )
                ptype = _probs.type();
            _probs.create(1, nclusters, ptype);
            probs = _probs.getMat();
        }

        return computeProbabilities(sample, !probs.empty() ? &probs : 0, ptype);
    }


    bool isTrained() const
    {
        return !means.empty();
    }

    bool isClassifier() const
    {
        return true;
    }

    int getVarCount() const
    {
        return means.cols;
    }

    String getDefaultName() const
    {
        return "opencv_ml_em";
    }

    static void checkTrainData(int startStep, const Mat& samples,
                               int nclusters, int covMatType, const Mat* probs, const Mat* means,
                               const std::vector<Mat>* covs, const Mat* weights)
    {
        // Check samples.
        CV_Assert(!samples.empty());
        CV_Assert(samples.channels() == 1);

        int nsamples = samples.rows;
        int dim = samples.cols;

        // Check training params.
        CV_Assert(nclusters > 0);
        CV_Assert(nclusters <= nsamples);
        CV_Assert(startStep == START_AUTO_STEP ||
                  startStep == START_E_STEP ||
                  startStep == START_M_STEP);
        CV_Assert(covMatType == COV_MAT_GENERIC ||
                  covMatType == COV_MAT_DIAGONAL ||
                  covMatType == COV_MAT_SPHERICAL);

        CV_Assert(!probs ||
            (!probs->empty() &&
             probs->rows == nsamples && probs->cols == nclusters &&
             (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));

        CV_Assert(!weights ||
            (!weights->empty() &&
             (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
             (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));

        CV_Assert(!means ||
            (!means->empty() &&
             means->rows == nclusters && means->cols == dim &&
             means->channels() == 1));

        CV_Assert(!covs ||
            (!covs->empty() &&
             static_cast<int>(covs->size()) == nclusters));
        if(covs)
        {
            const Size covSize(dim, dim);
            for(size_t i = 0; i < covs->size(); i++)
            {
                const Mat& m = (*covs)[i];
                CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
            }
        }

        if(startStep == START_E_STEP)
        {
            CV_Assert(means);
        }
        else if(startStep == START_M_STEP)
        {
            CV_Assert(probs);
        }
    }

    static void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
    {
        if(src.type() == dstType && !isAlwaysClone)
            dst = src;
        else
            src.convertTo(dst, dstType);
    }

    static void preprocessProbability(Mat& probs)
    {
        max(probs, 0., probs);

        const double uniformProbability = (double)(1./probs.cols);
        for(int y = 0; y < probs.rows; y++)
        {
            Mat sampleProbs = probs.row(y);

            double maxVal = 0;
            minMaxLoc(sampleProbs, 0, &maxVal);
            if(maxVal < FLT_EPSILON)
                sampleProbs.setTo(uniformProbability);
            else
                normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
        }
    }

    void setTrainData(int startStep, const Mat& samples,
                      const Mat* probs0,
                      const Mat* means0,
                      const std::vector<Mat>* covs0,
                      const Mat* weights0)
    {
        clear();

        checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);

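        // Samples are kept in CV_32F when k-means will be used for initialization
        // (kmeans() works on single-precision data); otherwise they are stored as
        // CV_64F for the EM iterations.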
        bool isKMeansInit = (startStep == START_AUTO_STEP) || (startStep == START_E_STEP && (covs0 == 0 || weights0 == 0));
        // Set checked data
        preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);

        // set probs
        if(probs0 && startStep == START_M_STEP)
        {
            preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
            preprocessProbability(trainProbs);
        }

        // set weights
        if(weights0 && (startStep == START_E_STEP && covs0))
        {
            weights0->convertTo(weights, CV_64FC1);
            weights = weights.reshape(1, 1);
            preprocessProbability(weights);
        }

        // set means
        if(means0 && (startStep == START_E_STEP/* || startStep == START_AUTO_STEP*/))
            means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);

        // set covs
        if(covs0 && (startStep == START_E_STEP && weights0))
        {
            covs.resize(nclusters);
            for(size_t i = 0; i < covs0->size(); i++)
                (*covs0)[i].convertTo(covs[i], CV_64FC1);
        }
    }

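    // Decompose each cluster covariance with SVD: svd.w holds the eigenvalues and,
    // for COV_MAT_GENERIC, svd.u holds the rotation (eigenvector) matrix, so the
    // Mahalanobis term can later be evaluated one dimension at a time. For
    // COV_MAT_SPHERICAL only the largest singular value is kept.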
    void decomposeCovs()
    {
        CV_Assert(!covs.empty());
        covsEigenValues.resize(nclusters);
        if(covMatType == COV_MAT_GENERIC)
            covsRotateMats.resize(nclusters);
        invCovsEigenValues.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            CV_Assert(!covs[clusterIndex].empty());

            SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);

            if(covMatType == COV_MAT_SPHERICAL)
            {
                double maxSingularVal = svd.w.at<double>(0);
                covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
            }
            else if(covMatType == COV_MAT_DIAGONAL)
            {
                covsEigenValues[clusterIndex] = svd.w;
            }
            else //COV_MAT_GENERIC
            {
                covsEigenValues[clusterIndex] = svd.w;
                covsRotateMats[clusterIndex] = svd.u;
            }
            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
        }
    }

    void clusterTrainSamples()
    {
        int nsamples = trainSamples.rows;

        // Cluster samples, compute/update means

        // Convert samples and means to 32F, because kmeans requires this type.
        Mat trainSamplesFlt, meansFlt;
        if(trainSamples.type() != CV_32FC1)
            trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
        else
            trainSamplesFlt = trainSamples;
        if(!means.empty())
        {
            if(means.type() != CV_32FC1)
                means.convertTo(meansFlt, CV_32FC1);
            else
                meansFlt = means;
        }

        Mat labels;
        kmeans(trainSamplesFlt, nclusters, labels,
               TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
               10, KMEANS_PP_CENTERS, meansFlt);

        // Convert samples and means back to 64F.
        CV_Assert(meansFlt.type() == CV_32FC1);
        if(trainSamples.type() != CV_64FC1)
        {
            Mat trainSamplesBuffer;
            trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
            trainSamples = trainSamplesBuffer;
        }
        meansFlt.convertTo(means, CV_64FC1);

        // Compute weights and covs
        weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
        covs.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            Mat clusterSamples;
            for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
            {
                if(labels.at<int>(sampleIndex) == clusterIndex)
                {
                    const Mat sample = trainSamples.row(sampleIndex);
                    clusterSamples.push_back(sample);
                }
            }
            CV_Assert(!clusterSamples.empty());

            calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
                CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
            weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
        }

        decomposeCovs();
    }

    void computeLogWeightDivDet()
    {
        CV_Assert(!covsEigenValues.empty());

        Mat logWeights;
        cv::max(weights, DBL_MIN, weights);
        log(weights, logWeights);

        logWeightDivDet.create(1, nclusters, CV_64FC1);
        // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
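        // This is the sample-independent part of log(weight_k * N(x | mean_k, cov_k));
        // the Mahalanobis term -0.5*(x - mean_k)' cov_k^-1 (x - mean_k) and the
        // -0.5*dim*log(2*pi) normalization are added per sample in computeProbabilities().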

        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            double logDetCov = 0.;
            const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
            for(int di = 0; di < evalCount; di++)
                logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));

            logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
        }
    }

    bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
    {
        int dim = trainSamples.cols;
        // Precompute the empty initial train data in the cases of START_E_STEP and START_AUTO_STEP
        if(startStep != START_M_STEP)
        {
            if(covs.empty())
            {
                CV_Assert(weights.empty());
                clusterTrainSamples();
            }
        }

        if(!covs.empty() && covsEigenValues.empty() )
        {
            CV_Assert(invCovsEigenValues.empty());
            decomposeCovs();
        }

        if(startStep == START_M_STEP)
            mStep();

        double trainLogLikelihood, prevTrainLogLikelihood = 0.;
        int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
            termCrit.maxCount : DEFAULT_MAX_ITERS;
        double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;

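        // EM main loop: alternate E-steps and M-steps until the iteration limit is
        // reached or the log-likelihood stops improving (the improvement falls below
        // epsilon * |logL|, or the likelihood decreases).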
        for(int iter = 0; ; iter++)
        {
            eStep();
            trainLogLikelihood = sum(trainLogLikelihoods)[0];

            if(iter >= maxIters - 1)
                break;

            double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
            if( iter != 0 &&
                (trainLogLikelihoodDelta < -DBL_EPSILON ||
                 trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
                break;

            mStep();

            prevTrainLogLikelihood = trainLogLikelihood;
        }

        if( trainLogLikelihood <= -DBL_MAX/10000. )
        {
            clear();
            return false;
        }

        // postprocess covs
        covs.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(covMatType == COV_MAT_SPHERICAL)
            {
                covs[clusterIndex].create(dim, dim, CV_64FC1);
                setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
            }
            else if(covMatType == COV_MAT_DIAGONAL)
            {
                covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
            }
        }

        if(labels.needed())
            trainLabels.copyTo(labels);
        if(probs.needed())
            trainProbs.copyTo(probs);
        if(logLikelihoods.needed())
            trainLogLikelihoods.copyTo(logLikelihoods);

        trainSamples.release();
        trainProbs.release();
        trainLabels.release();
        trainLogLikelihoods.release();

        return true;
    }

    Vec2d computeProbabilities(const Mat& sample, Mat* probs, int ptype) const
    {
        // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 * (x_i - mean_k)' cov_k^(-1) (x_i - mean_k)
        // q = arg(max_k(L_ik))
        // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq)))
        // see Alex Smola's blog http://blog.smola.org/page/2 for
        // details on the log-sum-exp trick
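        // The returned Vec2d holds {log-likelihood of the sample, index of the most
        // probable cluster}. The log-likelihood is evaluated stably as
        //     log(sum_k exp(L_ik)) = L_iq + log(sum_k exp(L_ik - L_iq)),
        // and the -0.5*dim*log(2*pi) normalization is added once at the end.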

        int stype = sample.type();
        CV_Assert(!means.empty());
        CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
        CV_Assert(sample.size() == Size(means.cols, 1));

        int dim = sample.cols;

        Mat L(1, nclusters, CV_64FC1), centeredSample(1, dim, CV_64F);
        int i, label = 0;
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            const double* mptr = means.ptr<double>(clusterIndex);
            double* dptr = centeredSample.ptr<double>();
            if( stype == CV_32F )
            {
                const float* sptr = sample.ptr<float>();
                for( i = 0; i < dim; i++ )
                    dptr[i] = sptr[i] - mptr[i];
            }
            else
            {
                const double* sptr = sample.ptr<double>();
                for( i = 0; i < dim; i++ )
                    dptr[i] = sptr[i] - mptr[i];
            }

            Mat rotatedCenteredSample = covMatType != COV_MAT_GENERIC ?
                    centeredSample : centeredSample * covsRotateMats[clusterIndex];

            double Lval = 0;
            for(int di = 0; di < dim; di++)
            {
                double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0);
                double val = rotatedCenteredSample.at<double>(di);
                Lval += w * val * val;
            }
            CV_DbgAssert(!logWeightDivDet.empty());
            L.at<double>(clusterIndex) = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;

            if(L.at<double>(clusterIndex) > L.at<double>(label))
                label = clusterIndex;
        }

        double maxLVal = L.at<double>(label);
        double expDiffSum = 0;
        for( i = 0; i < L.cols; i++ )
        {
            double v = std::exp(L.at<double>(i) - maxLVal);
            L.at<double>(i) = v;
            expDiffSum += v; // sum_j(exp(L_ij - L_iq))
        }

        if(probs)
            L.convertTo(*probs, ptype, 1./expDiffSum);

        Vec2d res;
        res[0] = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
        res[1] = label;

        return res;
    }

    void eStep()
    {
        // Compute probs_ik from means_k, covs_k and weights_k.
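        // Each row of trainProbs receives the posterior p(k | x_i) (the cluster
        // "responsibilities"), trainLabels the index of the most probable cluster,
        // and trainLogLikelihoods the per-sample log-likelihood.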
        trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
        trainLabels.create(trainSamples.rows, 1, CV_32SC1);
        trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);

        computeLogWeightDivDet();

        CV_DbgAssert(trainSamples.type() == CV_64FC1);
        CV_DbgAssert(means.type() == CV_64FC1);

        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
        {
            Mat sampleProbs = trainProbs.row(sampleIndex);
            Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs, CV_64F);
            trainLogLikelihoods.at<double>(sampleIndex) = res[0];
            trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
        }
    }

    void mStep()
    {
        // Update means_k, covs_k and weights_k from probs_ik
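        // Standard GMM M-step: with responsibilities p_ik,
        //     weight_k = sum_i p_ik (normalized by the sample count at the end),
        //     mean_k   = sum_i p_ik * x_i / sum_i p_ik,
        //     cov_k    = sum_i p_ik * (x_i - mean_k)(x_i - mean_k)' / sum_i p_ik.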
        int dim = trainSamples.cols;

        // Update weights
        // not normalized first
        reduce(trainProbs, weights, 0, CV_REDUCE_SUM);

        // Update means
        means.create(nclusters, dim, CV_64FC1);
        means = Scalar(0);

        const double minPosWeight = trainSamples.rows * DBL_EPSILON;
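        // Clusters whose accumulated responsibility does not exceed minPosWeight are
        // considered degenerate: they are skipped below and re-initialized at the end
        // of the M-step from the smallest cluster that still has a positive weight.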
        double minWeight = DBL_MAX;
        int minWeightClusterIndex = -1;
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
                continue;

            if(weights.at<double>(clusterIndex) < minWeight)
            {
                minWeight = weights.at<double>(clusterIndex);
                minWeightClusterIndex = clusterIndex;
            }

            Mat clusterMean = means.row(clusterIndex);
            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
                clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
            clusterMean /= weights.at<double>(clusterIndex);
        }

        // Update covsEigenValues and invCovsEigenValues
        covs.resize(nclusters);
        covsEigenValues.resize(nclusters);
        if(covMatType == COV_MAT_GENERIC)
            covsRotateMats.resize(nclusters);
        invCovsEigenValues.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
                continue;

            if(covMatType != COV_MAT_SPHERICAL)
                covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
            else
                covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);

            if(covMatType == COV_MAT_GENERIC)
                covs[clusterIndex].create(dim, dim, CV_64FC1);

            Mat clusterCov = covMatType != COV_MAT_GENERIC ?
                covsEigenValues[clusterIndex] : covs[clusterIndex];

            clusterCov = Scalar(0);

            Mat centeredSample;
            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
            {
                centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);

                if(covMatType == COV_MAT_GENERIC)
                    clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
                else
                {
                    double p = trainProbs.at<double>(sampleIndex, clusterIndex);
                    for(int di = 0; di < dim; di++ )
                    {
                        double val = centeredSample.at<double>(di);
                        clusterCov.at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0) += p*val*val;
                    }
                }
            }

            if(covMatType == COV_MAT_SPHERICAL)
                clusterCov /= dim;

            clusterCov /= weights.at<double>(clusterIndex);

            // Update covsRotateMats for COV_MAT_GENERIC only
            if(covMatType == COV_MAT_GENERIC)
            {
                SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
                covsEigenValues[clusterIndex] = svd.w;
                covsRotateMats[clusterIndex] = svd.u;
            }

            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);

            // update invCovsEigenValues
            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
        }

        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
            {
                Mat clusterMean = means.row(clusterIndex);
                means.row(minWeightClusterIndex).copyTo(clusterMean);
                covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
                covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
                if(covMatType == COV_MAT_GENERIC)
                    covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
                invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
            }
        }

        // Normalize weights
        weights /= trainSamples.rows;
    }

    void write_params(FileStorage& fs) const
    {
        fs << "nclusters" << nclusters;
        fs << "cov_mat_type" << (covMatType == COV_MAT_SPHERICAL ? String("spherical") :
                                 covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
                                 covMatType == COV_MAT_GENERIC ? String("generic") :
                                 format("unknown_%d", covMatType));
        writeTermCrit(fs, termCrit);
    }

    void write(FileStorage& fs) const
    {
        fs << "training_params" << "{";
        write_params(fs);
        fs << "}";
        fs << "weights" << weights;
        fs << "means" << means;

        size_t i, n = covs.size();

        fs << "covs" << "[";
        for( i = 0; i < n; i++ )
            fs << covs[i];
        fs << "]";
    }

    void read_params(const FileNode& fn)
    {
        nclusters = (int)fn["nclusters"];
        String s = (String)fn["cov_mat_type"];
        covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
                     s == "diagonal" ? COV_MAT_DIAGONAL :
                     s == "generic" ? COV_MAT_GENERIC : -1;
        CV_Assert(covMatType >= 0);
        termCrit = readTermCrit(fn);
    }

    void read(const FileNode& fn)
    {
        clear();
        read_params(fn["training_params"]);

        fn["weights"] >> weights;
        fn["means"] >> means;

        FileNode cfn = fn["covs"];
        FileNodeIterator cfn_it = cfn.begin();
        int i, n = (int)cfn.size();
        covs.resize(n);

        for( i = 0; i < n; i++, ++cfn_it )
            (*cfn_it) >> covs[i];

        decomposeCovs();
        computeLogWeightDivDet();
    }

    Mat getWeights() const { return weights; }
    Mat getMeans() const { return means; }
    void getCovs(std::vector<Mat>& _covs) const
    {
        _covs.resize(covs.size());
        std::copy(covs.begin(), covs.end(), _covs.begin());
    }

    // all inner matrices have type CV_64FC1
    Mat trainSamples;
    Mat trainProbs;
    Mat trainLogLikelihoods;
    Mat trainLabels;

    Mat weights;
    Mat means;
    std::vector<Mat> covs;

    std::vector<Mat> covsEigenValues;
    std::vector<Mat> covsRotateMats;
    std::vector<Mat> invCovsEigenValues;
    Mat logWeightDivDet;
};

Ptr<EM> EM::create()
{
    return makePtr<EMImpl>();
}
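
/* A minimal usage sketch of the API implemented above (illustrative only; the
   variable "samples" is assumed to be a CV_32F or CV_64F matrix with one sample
   per row):

     Ptr<EM> em = EM::create();
     em->setClustersNumber(3);
     em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL);
     Mat logLikelihoods, labels, probs;
     em->trainEM(samples, logLikelihoods, labels, probs);
     Vec2d r = em->predict2(samples.row(0), noArray());
     // r[0] is the log-likelihood of the sample, r[1] the index of the most
     // probable mixture component.
*/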

}
} // namespace cv

/* End of file. */