1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
12 //
// Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
21 //
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
25 //
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "precomp.hpp"
43
44 namespace cv
45 {
46 namespace ml
47 {
48
// Lower bound applied (via max()) to covariance eigenvalues before they are
// inverted, so the reciprocal never divides by zero.
const double minEigenValue = DBL_EPSILON;
50
51 class CV_EXPORTS EMImpl : public EM
52 {
53 public:
54
// Number of mixture components; the constructor sets it to DEFAULT_NCLUSTERS.
int nclusters;
// Covariance shape: COV_MAT_SPHERICAL, COV_MAT_DIAGONAL or COV_MAT_GENERIC.
int covMatType;
// Termination criteria read by doTrain() (iteration cap and/or epsilon).
TermCriteria termCrit;

// NOTE(review): presumably expands to the TermCriteria getter/setter pair
// backed by termCrit -- the macro is defined outside this file; confirm.
CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, termCrit)
60
61 void setClustersNumber(int val)
62 {
63 nclusters = val;
64 CV_Assert(nclusters > 1);
65 }
66
getClustersNumber() const67 int getClustersNumber() const
68 {
69 return nclusters;
70 }
71
setCovarianceMatrixType(int val)72 void setCovarianceMatrixType(int val)
73 {
74 covMatType = val;
75 CV_Assert(covMatType == COV_MAT_SPHERICAL ||
76 covMatType == COV_MAT_DIAGONAL ||
77 covMatType == COV_MAT_GENERIC);
78 }
79
getCovarianceMatrixType() const80 int getCovarianceMatrixType() const
81 {
82 return covMatType;
83 }
84
EMImpl()85 EMImpl()
86 {
87 nclusters = DEFAULT_NCLUSTERS;
88 covMatType=EM::COV_MAT_DIAGONAL;
89 termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
90 }
91
// Nothing to release by hand; the destructor body is intentionally empty.
virtual ~EMImpl()
{
}
93
clear()94 void clear()
95 {
96 trainSamples.release();
97 trainProbs.release();
98 trainLogLikelihoods.release();
99 trainLabels.release();
100
101 weights.release();
102 means.release();
103 covs.clear();
104
105 covsEigenValues.clear();
106 invCovsEigenValues.clear();
107 covsRotateMats.clear();
108
109 logWeightDivDet.release();
110 }
111
train(const Ptr<TrainData> & data,int)112 bool train(const Ptr<TrainData>& data, int)
113 {
114 Mat samples = data->getTrainSamples(), labels;
115 return trainEM(samples, labels, noArray(), noArray());
116 }
117
trainEM(InputArray samples,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)118 bool trainEM(InputArray samples,
119 OutputArray logLikelihoods,
120 OutputArray labels,
121 OutputArray probs)
122 {
123 Mat samplesMat = samples.getMat();
124 setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
125 return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
126 }
127
trainE(InputArray samples,InputArray _means0,InputArray _covs0,InputArray _weights0,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)128 bool trainE(InputArray samples,
129 InputArray _means0,
130 InputArray _covs0,
131 InputArray _weights0,
132 OutputArray logLikelihoods,
133 OutputArray labels,
134 OutputArray probs)
135 {
136 Mat samplesMat = samples.getMat();
137 std::vector<Mat> covs0;
138 _covs0.getMatVector(covs0);
139
140 Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();
141
142 setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
143 !_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
144 return doTrain(START_E_STEP, logLikelihoods, labels, probs);
145 }
146
trainM(InputArray samples,InputArray _probs0,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)147 bool trainM(InputArray samples,
148 InputArray _probs0,
149 OutputArray logLikelihoods,
150 OutputArray labels,
151 OutputArray probs)
152 {
153 Mat samplesMat = samples.getMat();
154 Mat probs0 = _probs0.getMat();
155
156 setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
157 return doTrain(START_M_STEP, logLikelihoods, labels, probs);
158 }
159
predict(InputArray _inputs,OutputArray _outputs,int) const160 float predict(InputArray _inputs, OutputArray _outputs, int) const
161 {
162 bool needprobs = _outputs.needed();
163 Mat samples = _inputs.getMat(), probs, probsrow;
164 int ptype = CV_32F;
165 float firstres = 0.f;
166 int i, nsamples = samples.rows;
167
168 if( needprobs )
169 {
170 if( _outputs.fixedType() )
171 ptype = _outputs.type();
172 _outputs.create(samples.rows, nclusters, ptype);
173 }
174 else
175 nsamples = std::min(nsamples, 1);
176
177 for( i = 0; i < nsamples; i++ )
178 {
179 if( needprobs )
180 probsrow = probs.row(i);
181 Vec2d res = computeProbabilities(samples.row(i), needprobs ? &probsrow : 0, ptype);
182 if( i == 0 )
183 firstres = (float)res[1];
184 }
185 return firstres;
186 }
187
predict2(InputArray _sample,OutputArray _probs) const188 Vec2d predict2(InputArray _sample, OutputArray _probs) const
189 {
190 int ptype = CV_32F;
191 Mat sample = _sample.getMat();
192 CV_Assert(isTrained());
193
194 CV_Assert(!sample.empty());
195 if(sample.type() != CV_64FC1)
196 {
197 Mat tmp;
198 sample.convertTo(tmp, CV_64FC1);
199 sample = tmp;
200 }
201 sample.reshape(1, 1);
202
203 Mat probs;
204 if( _probs.needed() )
205 {
206 if( _probs.fixedType() )
207 ptype = _probs.type();
208 _probs.create(1, nclusters, ptype);
209 probs = _probs.getMat();
210 }
211
212 return computeProbabilities(sample, !probs.empty() ? &probs : 0, ptype);
213 }
214
isTrained() const215 bool isTrained() const
216 {
217 return !means.empty();
218 }
219
isClassifier() const220 bool isClassifier() const
221 {
222 return true;
223 }
224
getVarCount() const225 int getVarCount() const
226 {
227 return means.cols;
228 }
229
getDefaultName() const230 String getDefaultName() const
231 {
232 return "opencv_ml_em";
233 }
234
checkTrainData(int startStep,const Mat & samples,int nclusters,int covMatType,const Mat * probs,const Mat * means,const std::vector<Mat> * covs,const Mat * weights)235 static void checkTrainData(int startStep, const Mat& samples,
236 int nclusters, int covMatType, const Mat* probs, const Mat* means,
237 const std::vector<Mat>* covs, const Mat* weights)
238 {
239 // Check samples.
240 CV_Assert(!samples.empty());
241 CV_Assert(samples.channels() == 1);
242
243 int nsamples = samples.rows;
244 int dim = samples.cols;
245
246 // Check training params.
247 CV_Assert(nclusters > 0);
248 CV_Assert(nclusters <= nsamples);
249 CV_Assert(startStep == START_AUTO_STEP ||
250 startStep == START_E_STEP ||
251 startStep == START_M_STEP);
252 CV_Assert(covMatType == COV_MAT_GENERIC ||
253 covMatType == COV_MAT_DIAGONAL ||
254 covMatType == COV_MAT_SPHERICAL);
255
256 CV_Assert(!probs ||
257 (!probs->empty() &&
258 probs->rows == nsamples && probs->cols == nclusters &&
259 (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));
260
261 CV_Assert(!weights ||
262 (!weights->empty() &&
263 (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
264 (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));
265
266 CV_Assert(!means ||
267 (!means->empty() &&
268 means->rows == nclusters && means->cols == dim &&
269 means->channels() == 1));
270
271 CV_Assert(!covs ||
272 (!covs->empty() &&
273 static_cast<int>(covs->size()) == nclusters));
274 if(covs)
275 {
276 const Size covSize(dim, dim);
277 for(size_t i = 0; i < covs->size(); i++)
278 {
279 const Mat& m = (*covs)[i];
280 CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
281 }
282 }
283
284 if(startStep == START_E_STEP)
285 {
286 CV_Assert(means);
287 }
288 else if(startStep == START_M_STEP)
289 {
290 CV_Assert(probs);
291 }
292 }
293
preprocessSampleData(const Mat & src,Mat & dst,int dstType,bool isAlwaysClone)294 static void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
295 {
296 if(src.type() == dstType && !isAlwaysClone)
297 dst = src;
298 else
299 src.convertTo(dst, dstType);
300 }
301
preprocessProbability(Mat & probs)302 static void preprocessProbability(Mat& probs)
303 {
304 max(probs, 0., probs);
305
306 const double uniformProbability = (double)(1./probs.cols);
307 for(int y = 0; y < probs.rows; y++)
308 {
309 Mat sampleProbs = probs.row(y);
310
311 double maxVal = 0;
312 minMaxLoc(sampleProbs, 0, &maxVal);
313 if(maxVal < FLT_EPSILON)
314 sampleProbs.setTo(uniformProbability);
315 else
316 normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
317 }
318 }
319
setTrainData(int startStep,const Mat & samples,const Mat * probs0,const Mat * means0,const std::vector<Mat> * covs0,const Mat * weights0)320 void setTrainData(int startStep, const Mat& samples,
321 const Mat* probs0,
322 const Mat* means0,
323 const std::vector<Mat>* covs0,
324 const Mat* weights0)
325 {
326 clear();
327
328 checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);
329
330 bool isKMeansInit = (startStep == START_AUTO_STEP) || (startStep == START_E_STEP && (covs0 == 0 || weights0 == 0));
331 // Set checked data
332 preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);
333
334 // set probs
335 if(probs0 && startStep == START_M_STEP)
336 {
337 preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
338 preprocessProbability(trainProbs);
339 }
340
341 // set weights
342 if(weights0 && (startStep == START_E_STEP && covs0))
343 {
344 weights0->convertTo(weights, CV_64FC1);
345 weights.reshape(1,1);
346 preprocessProbability(weights);
347 }
348
349 // set means
350 if(means0 && (startStep == START_E_STEP/* || startStep == START_AUTO_STEP*/))
351 means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);
352
353 // set covs
354 if(covs0 && (startStep == START_E_STEP && weights0))
355 {
356 covs.resize(nclusters);
357 for(size_t i = 0; i < covs0->size(); i++)
358 (*covs0)[i].convertTo(covs[i], CV_64FC1);
359 }
360 }
361
decomposeCovs()362 void decomposeCovs()
363 {
364 CV_Assert(!covs.empty());
365 covsEigenValues.resize(nclusters);
366 if(covMatType == COV_MAT_GENERIC)
367 covsRotateMats.resize(nclusters);
368 invCovsEigenValues.resize(nclusters);
369 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
370 {
371 CV_Assert(!covs[clusterIndex].empty());
372
373 SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
374
375 if(covMatType == COV_MAT_SPHERICAL)
376 {
377 double maxSingularVal = svd.w.at<double>(0);
378 covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
379 }
380 else if(covMatType == COV_MAT_DIAGONAL)
381 {
382 covsEigenValues[clusterIndex] = svd.w;
383 }
384 else //COV_MAT_GENERIC
385 {
386 covsEigenValues[clusterIndex] = svd.w;
387 covsRotateMats[clusterIndex] = svd.u;
388 }
389 max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
390 invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
391 }
392 }
393
clusterTrainSamples()394 void clusterTrainSamples()
395 {
396 int nsamples = trainSamples.rows;
397
398 // Cluster samples, compute/update means
399
400 // Convert samples and means to 32F, because kmeans requires this type.
401 Mat trainSamplesFlt, meansFlt;
402 if(trainSamples.type() != CV_32FC1)
403 trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
404 else
405 trainSamplesFlt = trainSamples;
406 if(!means.empty())
407 {
408 if(means.type() != CV_32FC1)
409 means.convertTo(meansFlt, CV_32FC1);
410 else
411 meansFlt = means;
412 }
413
414 Mat labels;
415 kmeans(trainSamplesFlt, nclusters, labels,
416 TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
417 10, KMEANS_PP_CENTERS, meansFlt);
418
419 // Convert samples and means back to 64F.
420 CV_Assert(meansFlt.type() == CV_32FC1);
421 if(trainSamples.type() != CV_64FC1)
422 {
423 Mat trainSamplesBuffer;
424 trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
425 trainSamples = trainSamplesBuffer;
426 }
427 meansFlt.convertTo(means, CV_64FC1);
428
429 // Compute weights and covs
430 weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
431 covs.resize(nclusters);
432 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
433 {
434 Mat clusterSamples;
435 for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
436 {
437 if(labels.at<int>(sampleIndex) == clusterIndex)
438 {
439 const Mat sample = trainSamples.row(sampleIndex);
440 clusterSamples.push_back(sample);
441 }
442 }
443 CV_Assert(!clusterSamples.empty());
444
445 calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
446 CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
447 weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
448 }
449
450 decomposeCovs();
451 }
452
computeLogWeightDivDet()453 void computeLogWeightDivDet()
454 {
455 CV_Assert(!covsEigenValues.empty());
456
457 Mat logWeights;
458 cv::max(weights, DBL_MIN, weights);
459 log(weights, logWeights);
460
461 logWeightDivDet.create(1, nclusters, CV_64FC1);
462 // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
463
464 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
465 {
466 double logDetCov = 0.;
467 const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
468 for(int di = 0; di < evalCount; di++)
469 logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));
470
471 logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
472 }
473 }
474
doTrain(int startStep,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)475 bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
476 {
477 int dim = trainSamples.cols;
478 // Precompute the empty initial train data in the cases of START_E_STEP and START_AUTO_STEP
479 if(startStep != START_M_STEP)
480 {
481 if(covs.empty())
482 {
483 CV_Assert(weights.empty());
484 clusterTrainSamples();
485 }
486 }
487
488 if(!covs.empty() && covsEigenValues.empty() )
489 {
490 CV_Assert(invCovsEigenValues.empty());
491 decomposeCovs();
492 }
493
494 if(startStep == START_M_STEP)
495 mStep();
496
497 double trainLogLikelihood, prevTrainLogLikelihood = 0.;
498 int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
499 termCrit.maxCount : DEFAULT_MAX_ITERS;
500 double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;
501
502 for(int iter = 0; ; iter++)
503 {
504 eStep();
505 trainLogLikelihood = sum(trainLogLikelihoods)[0];
506
507 if(iter >= maxIters - 1)
508 break;
509
510 double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
511 if( iter != 0 &&
512 (trainLogLikelihoodDelta < -DBL_EPSILON ||
513 trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
514 break;
515
516 mStep();
517
518 prevTrainLogLikelihood = trainLogLikelihood;
519 }
520
521 if( trainLogLikelihood <= -DBL_MAX/10000. )
522 {
523 clear();
524 return false;
525 }
526
527 // postprocess covs
528 covs.resize(nclusters);
529 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
530 {
531 if(covMatType == COV_MAT_SPHERICAL)
532 {
533 covs[clusterIndex].create(dim, dim, CV_64FC1);
534 setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
535 }
536 else if(covMatType == COV_MAT_DIAGONAL)
537 {
538 covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
539 }
540 }
541
542 if(labels.needed())
543 trainLabels.copyTo(labels);
544 if(probs.needed())
545 trainProbs.copyTo(probs);
546 if(logLikelihoods.needed())
547 trainLogLikelihoods.copyTo(logLikelihoods);
548
549 trainSamples.release();
550 trainProbs.release();
551 trainLabels.release();
552 trainLogLikelihoods.release();
553
554 return true;
555 }
556
computeProbabilities(const Mat & sample,Mat * probs,int ptype) const557 Vec2d computeProbabilities(const Mat& sample, Mat* probs, int ptype) const
558 {
559 // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
560 // q = arg(max_k(L_ik))
561 // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
562 // see Alex Smola's blog http://blog.smola.org/page/2 for
563 // details on the log-sum-exp trick
564
565 int stype = sample.type();
566 CV_Assert(!means.empty());
567 CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
568 CV_Assert(sample.size() == Size(means.cols, 1));
569
570 int dim = sample.cols;
571
572 Mat L(1, nclusters, CV_64FC1), centeredSample(1, dim, CV_64F);
573 int i, label = 0;
574 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
575 {
576 const double* mptr = means.ptr<double>(clusterIndex);
577 double* dptr = centeredSample.ptr<double>();
578 if( stype == CV_32F )
579 {
580 const float* sptr = sample.ptr<float>();
581 for( i = 0; i < dim; i++ )
582 dptr[i] = sptr[i] - mptr[i];
583 }
584 else
585 {
586 const double* sptr = sample.ptr<double>();
587 for( i = 0; i < dim; i++ )
588 dptr[i] = sptr[i] - mptr[i];
589 }
590
591 Mat rotatedCenteredSample = covMatType != COV_MAT_GENERIC ?
592 centeredSample : centeredSample * covsRotateMats[clusterIndex];
593
594 double Lval = 0;
595 for(int di = 0; di < dim; di++)
596 {
597 double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0);
598 double val = rotatedCenteredSample.at<double>(di);
599 Lval += w * val * val;
600 }
601 CV_DbgAssert(!logWeightDivDet.empty());
602 L.at<double>(clusterIndex) = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;
603
604 if(L.at<double>(clusterIndex) > L.at<double>(label))
605 label = clusterIndex;
606 }
607
608 double maxLVal = L.at<double>(label);
609 double expDiffSum = 0;
610 for( i = 0; i < L.cols; i++ )
611 {
612 double v = std::exp(L.at<double>(i) - maxLVal);
613 L.at<double>(i) = v;
614 expDiffSum += v; // sum_j(exp(L_ij - L_iq))
615 }
616
617 if(probs)
618 L.convertTo(*probs, ptype, 1./expDiffSum);
619
620 Vec2d res;
621 res[0] = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
622 res[1] = label;
623
624 return res;
625 }
626
eStep()627 void eStep()
628 {
629 // Compute probs_ik from means_k, covs_k and weights_k.
630 trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
631 trainLabels.create(trainSamples.rows, 1, CV_32SC1);
632 trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);
633
634 computeLogWeightDivDet();
635
636 CV_DbgAssert(trainSamples.type() == CV_64FC1);
637 CV_DbgAssert(means.type() == CV_64FC1);
638
639 for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
640 {
641 Mat sampleProbs = trainProbs.row(sampleIndex);
642 Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs, CV_64F);
643 trainLogLikelihoods.at<double>(sampleIndex) = res[0];
644 trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
645 }
646 }
647
mStep()648 void mStep()
649 {
650 // Update means_k, covs_k and weights_k from probs_ik
651 int dim = trainSamples.cols;
652
653 // Update weights
654 // not normalized first
655 reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
656
657 // Update means
658 means.create(nclusters, dim, CV_64FC1);
659 means = Scalar(0);
660
661 const double minPosWeight = trainSamples.rows * DBL_EPSILON;
662 double minWeight = DBL_MAX;
663 int minWeightClusterIndex = -1;
664 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
665 {
666 if(weights.at<double>(clusterIndex) <= minPosWeight)
667 continue;
668
669 if(weights.at<double>(clusterIndex) < minWeight)
670 {
671 minWeight = weights.at<double>(clusterIndex);
672 minWeightClusterIndex = clusterIndex;
673 }
674
675 Mat clusterMean = means.row(clusterIndex);
676 for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
677 clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
678 clusterMean /= weights.at<double>(clusterIndex);
679 }
680
681 // Update covsEigenValues and invCovsEigenValues
682 covs.resize(nclusters);
683 covsEigenValues.resize(nclusters);
684 if(covMatType == COV_MAT_GENERIC)
685 covsRotateMats.resize(nclusters);
686 invCovsEigenValues.resize(nclusters);
687 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
688 {
689 if(weights.at<double>(clusterIndex) <= minPosWeight)
690 continue;
691
692 if(covMatType != COV_MAT_SPHERICAL)
693 covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
694 else
695 covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
696
697 if(covMatType == COV_MAT_GENERIC)
698 covs[clusterIndex].create(dim, dim, CV_64FC1);
699
700 Mat clusterCov = covMatType != COV_MAT_GENERIC ?
701 covsEigenValues[clusterIndex] : covs[clusterIndex];
702
703 clusterCov = Scalar(0);
704
705 Mat centeredSample;
706 for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
707 {
708 centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
709
710 if(covMatType == COV_MAT_GENERIC)
711 clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
712 else
713 {
714 double p = trainProbs.at<double>(sampleIndex, clusterIndex);
715 for(int di = 0; di < dim; di++ )
716 {
717 double val = centeredSample.at<double>(di);
718 clusterCov.at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0) += p*val*val;
719 }
720 }
721 }
722
723 if(covMatType == COV_MAT_SPHERICAL)
724 clusterCov /= dim;
725
726 clusterCov /= weights.at<double>(clusterIndex);
727
728 // Update covsRotateMats for COV_MAT_GENERIC only
729 if(covMatType == COV_MAT_GENERIC)
730 {
731 SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
732 covsEigenValues[clusterIndex] = svd.w;
733 covsRotateMats[clusterIndex] = svd.u;
734 }
735
736 max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
737
738 // update invCovsEigenValues
739 invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
740 }
741
742 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
743 {
744 if(weights.at<double>(clusterIndex) <= minPosWeight)
745 {
746 Mat clusterMean = means.row(clusterIndex);
747 means.row(minWeightClusterIndex).copyTo(clusterMean);
748 covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
749 covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
750 if(covMatType == COV_MAT_GENERIC)
751 covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
752 invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
753 }
754 }
755
756 // Normalize weights
757 weights /= trainSamples.rows;
758 }
759
write_params(FileStorage & fs) const760 void write_params(FileStorage& fs) const
761 {
762 fs << "nclusters" << nclusters;
763 fs << "cov_mat_type" << (covMatType == COV_MAT_SPHERICAL ? String("spherical") :
764 covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
765 covMatType == COV_MAT_GENERIC ? String("generic") :
766 format("unknown_%d", covMatType));
767 writeTermCrit(fs, termCrit);
768 }
769
write(FileStorage & fs) const770 void write(FileStorage& fs) const
771 {
772 fs << "training_params" << "{";
773 write_params(fs);
774 fs << "}";
775 fs << "weights" << weights;
776 fs << "means" << means;
777
778 size_t i, n = covs.size();
779
780 fs << "covs" << "[";
781 for( i = 0; i < n; i++ )
782 fs << covs[i];
783 fs << "]";
784 }
785
read_params(const FileNode & fn)786 void read_params(const FileNode& fn)
787 {
788 nclusters = (int)fn["nclusters"];
789 String s = (String)fn["cov_mat_type"];
790 covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
791 s == "diagonal" ? COV_MAT_DIAGONAL :
792 s == "generic" ? COV_MAT_GENERIC : -1;
793 CV_Assert(covMatType >= 0);
794 termCrit = readTermCrit(fn);
795 }
796
read(const FileNode & fn)797 void read(const FileNode& fn)
798 {
799 clear();
800 read_params(fn["training_params"]);
801
802 fn["weights"] >> weights;
803 fn["means"] >> means;
804
805 FileNode cfn = fn["covs"];
806 FileNodeIterator cfn_it = cfn.begin();
807 int i, n = (int)cfn.size();
808 covs.resize(n);
809
810 for( i = 0; i < n; i++, ++cfn_it )
811 (*cfn_it) >> covs[i];
812
813 decomposeCovs();
814 computeLogWeightDivDet();
815 }
816
getWeights() const817 Mat getWeights() const { return weights; }
getMeans() const818 Mat getMeans() const { return means; }
getCovs(std::vector<Mat> & _covs) const819 void getCovs(std::vector<Mat>& _covs) const
820 {
821 _covs.resize(covs.size());
822 std::copy(covs.begin(), covs.end(), _covs.begin());
823 }
824
// Training-time buffers; doTrain() releases them once training succeeds.
// eStep() creates trainProbs and trainLogLikelihoods as CV_64FC1 and
// trainLabels as CV_32SC1; setTrainData() stores trainSamples as CV_32FC1
// or CV_64FC1.
Mat trainSamples;
Mat trainProbs;
Mat trainLogLikelihoods;
Mat trainLabels;

// Learned mixture parameters: component weights, one mean per row, and one
// covariance matrix per component.
Mat weights;
Mat means;
std::vector<Mat> covs;

// Values derived from covs and weights. covsRotateMats is only filled for
// COV_MAT_GENERIC. logWeightDivDet holds, per component,
// log(weight_k) - 0.5 * log(|det(cov_k)|).
std::vector<Mat> covsEigenValues;
std::vector<Mat> covsRotateMats;
std::vector<Mat> invCovsEigenValues;
Mat logWeightDivDet;
839 };
840
create()841 Ptr<EM> EM::create()
842 {
843 return makePtr<EMImpl>();
844 }
845
846 }
847 } // namespace cv
848
849 /* End of file. */
850