1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41 
42 /*
43     Partially based on Yossi Rubner code:
44     =========================================================================
45     emd.c
46 
47     Last update: 3/14/98
48 
49     An implementation of the Earth Movers Distance.
50     Based on the solution for the Transportation problem as described in
51     "Introduction to Mathematical Programming" by F. S. Hillier and
52     G. J. Lieberman, McGraw-Hill, 1990.
53 
54     Copyright (C) 1998 Yossi Rubner
55     Computer Science Department, Stanford University
56     E-Mail: rubner@cs.stanford.edu   URL: http://vision.stanford.edu/~rubner
57     ==========================================================================
58 */
59 #include "precomp.hpp"
60 
/* Upper bound on the number of simplex pivots attempted by cvCalcEMD2. */
#define MAX_ITERATIONS  500
/* Sentinel standing in for "infinitely large" costs and reduced costs. */
#define CV_EMD_INF      (static_cast<float>(1e20))
/* Relative tolerance; callers scale it by the largest ground distance or the total supply. */
#define CV_EMD_EPS      (static_cast<float>(1e-5))
64 
/* Singly linked node holding one value: the building block for the 1D sparse
   lists (row potentials u, column potentials v) used by the solver. */
struct CvNode1D
{
    float val;         /* stored value (e.g. a potential) */
    CvNode1D *next;    /* following node, or 0 at the end of the list */
};
72 
/* Node of the 2D sparse matrix of basic variables.  Each node sits on two
   lists at once: its row list (next[0]) and its column list (next[1]). */
struct CvNode2D
{
    float val;             /* flow carried by this cell */
    CvNode2D *next[2];     /* [0] = next in the same row, [1] = next in the same column */
    int i, j;              /* row / column of this cell */
};
81 
82 
83 typedef struct CvEMDState
84 {
85     int ssize, dsize;
86 
87     float **cost;
88     CvNode2D *_x;
89     CvNode2D *end_x;
90     CvNode2D *enter_x;
91     char **is_x;
92 
93     CvNode2D **rows_x;
94     CvNode2D **cols_x;
95 
96     CvNode1D *u;
97     CvNode1D *v;
98 
99     int* idx1;
100     int* idx2;
101 
102     /* find_loop buffers */
103     CvNode2D **loop;
104     char *is_used;
105 
106     /* russel buffers */
107     float *s;
108     float *d;
109     float **delta;
110 
111     float weight, max_cost;
112     char *buffer;
113 }
114 CvEMDState;
115 
116 /* static function declaration */
117 static int icvInitEMD( const float *signature1, int size1,
118                        const float *signature2, int size2,
119                        int dims, CvDistanceFunction dist_func, void *user_param,
120                        const float* cost, int cost_step,
121                        CvEMDState * state, float *lower_bound,
122                        cv::AutoBuffer<char>& _buffer );
123 
124 static int icvFindBasicVariables( float **cost, char **is_x,
125                                   CvNode1D * u, CvNode1D * v, int ssize, int dsize );
126 
127 static float icvIsOptimal( float **cost, char **is_x,
128                            CvNode1D * u, CvNode1D * v,
129                            int ssize, int dsize, CvNode2D * enter_x );
130 
131 static void icvRussel( CvEMDState * state );
132 
133 
134 static bool icvNewSolution( CvEMDState * state );
135 static int icvFindLoop( CvEMDState * state );
136 
137 static void icvAddBasicVariable( CvEMDState * state,
138                                  int min_i, int min_j,
139                                  CvNode1D * prev_u_min_i,
140                                  CvNode1D * prev_v_min_j,
141                                  CvNode1D * u_head );
142 
143 static float icvDistL2( const float *x, const float *y, void *user_param );
144 static float icvDistL1( const float *x, const float *y, void *user_param );
145 static float icvDistC( const float *x, const float *y, void *user_param );
146 
147 /* The main function */
cvCalcEMD2(const CvArr * signature_arr1,const CvArr * signature_arr2,int dist_type,CvDistanceFunction dist_func,const CvArr * cost_matrix,CvArr * flow_matrix,float * lower_bound,void * user_param)148 CV_IMPL float cvCalcEMD2( const CvArr* signature_arr1,
149             const CvArr* signature_arr2,
150             int dist_type,
151             CvDistanceFunction dist_func,
152             const CvArr* cost_matrix,
153             CvArr* flow_matrix,
154             float *lower_bound,
155             void *user_param )
156 {
157     cv::AutoBuffer<char> local_buf;
158     CvEMDState state;
159     float emd = 0;
160 
161     memset( &state, 0, sizeof(state));
162 
163     double total_cost = 0;
164     int result = 0;
165     float eps, min_delta;
166     CvNode2D *xp = 0;
167     CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
168     CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
169     CvMat cost_stub, *cost = &cost_stub;
170     CvMat flow_stub, *flow = (CvMat*)flow_matrix;
171     int dims, size1, size2;
172 
173     signature1 = cvGetMat( signature1, &sign_stub1 );
174     signature2 = cvGetMat( signature2, &sign_stub2 );
175 
176     if( signature1->cols != signature2->cols )
177         CV_Error( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
178 
179     dims = signature1->cols - 1;
180     size1 = signature1->rows;
181     size2 = signature2->rows;
182 
183     if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
184         CV_Error( CV_StsUnmatchedFormats, "The array must have equal types" );
185 
186     if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
187         CV_Error( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
188 
189     if( flow )
190     {
191         flow = cvGetMat( flow, &flow_stub );
192 
193         if( flow->rows != size1 || flow->cols != size2 )
194             CV_Error( CV_StsUnmatchedSizes,
195             "The flow matrix size does not match to the signatures' sizes" );
196 
197         if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
198             CV_Error( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
199     }
200 
201     cost->data.fl = 0;
202     cost->step = 0;
203 
204     if( dist_type < 0 )
205     {
206         if( cost_matrix )
207         {
208             if( dist_func )
209                 CV_Error( CV_StsBadArg,
210                 "Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
211 
212             if( lower_bound )
213                 CV_Error( CV_StsBadArg,
214                 "The lower boundary can not be calculated if the cost matrix is used" );
215 
216             cost = cvGetMat( cost_matrix, &cost_stub );
217             if( cost->rows != size1 || cost->cols != size2 )
218                 CV_Error( CV_StsUnmatchedSizes,
219                 "The cost matrix size does not match to the signatures' sizes" );
220 
221             if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
222                 CV_Error( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
223         }
224         else if( !dist_func )
225             CV_Error( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
226     }
227     else
228     {
229         if( dims == 0 )
230             CV_Error( CV_StsBadSize,
231             "Number of dimensions can be 0 only if a user-defined metric is used" );
232         user_param = (void *) (size_t)dims;
233         switch (dist_type)
234         {
235         case CV_DIST_L1:
236             dist_func = icvDistL1;
237             break;
238         case CV_DIST_L2:
239             dist_func = icvDistL2;
240             break;
241         case CV_DIST_C:
242             dist_func = icvDistC;
243             break;
244         default:
245             CV_Error( CV_StsBadFlag, "Bad or unsupported metric type" );
246         }
247     }
248 
249     result = icvInitEMD( signature1->data.fl, size1,
250                         signature2->data.fl, size2,
251                         dims, dist_func, user_param,
252                         cost->data.fl, cost->step,
253                         &state, lower_bound, local_buf );
254 
255     if( result > 0 && lower_bound )
256     {
257         emd = *lower_bound;
258         return emd;
259     }
260 
261     eps = CV_EMD_EPS * state.max_cost;
262 
263     /* if ssize = 1 or dsize = 1 then we are done, else ... */
264     if( state.ssize > 1 && state.dsize > 1 )
265     {
266         int itr;
267 
268         for( itr = 1; itr < MAX_ITERATIONS; itr++ )
269         {
270             /* find basic variables */
271             result = icvFindBasicVariables( state.cost, state.is_x,
272                                             state.u, state.v, state.ssize, state.dsize );
273             if( result < 0 )
274                 break;
275 
276             /* check for optimality */
277             min_delta = icvIsOptimal( state.cost, state.is_x,
278                                       state.u, state.v,
279                                       state.ssize, state.dsize, state.enter_x );
280 
281             if( min_delta == CV_EMD_INF )
282                 CV_Error( CV_StsNoConv, "" );
283 
284             /* if no negative deltamin, we found the optimal solution */
285             if( min_delta >= -eps )
286                 break;
287 
288             /* improve solution */
289             if(!icvNewSolution( &state ))
290                 CV_Error( CV_StsNoConv, "" );
291         }
292     }
293 
294     /* compute the total flow */
295     for( xp = state._x; xp < state.end_x; xp++ )
296     {
297         float val = xp->val;
298         int i = xp->i;
299         int j = xp->j;
300 
301         if( xp == state.enter_x )
302           continue;
303 
304         int ci = state.idx1[i];
305         int cj = state.idx2[j];
306 
307         if( ci >= 0 && cj >= 0 )
308         {
309             total_cost += (double)val * state.cost[i][j];
310             if( flow )
311                 ((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
312         }
313     }
314 
315     emd = (float) (total_cost / state.weight);
316     return emd;
317 }
318 
319 
320 /************************************************************************************\
321 *          initialize structure, allocate buffers and generate initial golution      *
322 \************************************************************************************/
icvInitEMD(const float * signature1,int size1,const float * signature2,int size2,int dims,CvDistanceFunction dist_func,void * user_param,const float * cost,int cost_step,CvEMDState * state,float * lower_bound,cv::AutoBuffer<char> & _buffer)323 static int icvInitEMD( const float* signature1, int size1,
324             const float* signature2, int size2,
325             int dims, CvDistanceFunction dist_func, void* user_param,
326             const float* cost, int cost_step,
327             CvEMDState* state, float* lower_bound,
328             cv::AutoBuffer<char>& _buffer )
329 {
330     float s_sum = 0, d_sum = 0, diff;
331     int i, j;
332     int ssize = 0, dsize = 0;
333     int equal_sums = 1;
334     int buffer_size;
335     float max_cost = 0;
336     char *buffer, *buffer_end;
337 
338     memset( state, 0, sizeof( *state ));
339     assert( cost_step % sizeof(float) == 0 );
340     cost_step /= sizeof(float);
341 
342     /* calculate buffer size */
343     buffer_size = (size1+1) * (size2+1) * (sizeof( float ) +    /* cost */
344                                    sizeof( char ) +     /* is_x */
345                                    sizeof( float )) +   /* delta matrix */
346         (size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
347                            sizeof( CvNode2D * ) +  /* cols_x & rows_x */
348                            sizeof( CvNode1D ) + /* u & v */
349                            sizeof( float ) + /* s & d */
350                            sizeof( int ) + sizeof(CvNode2D*)) +  /* idx1 & idx2 */
351         (size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
352                  sizeof( float * )) + 256;      /*  cost, is_x and delta */
353 
354     if( buffer_size < (int) (dims * 2 * sizeof( float )))
355     {
356         buffer_size = dims * 2 * sizeof( float );
357     }
358 
359     /* allocate buffers */
360     _buffer.allocate(buffer_size);
361 
362     state->buffer = buffer = _buffer;
363     buffer_end = buffer + buffer_size;
364 
365     state->idx1 = (int*) buffer;
366     buffer += (size1 + 1) * sizeof( int );
367 
368     state->idx2 = (int*) buffer;
369     buffer += (size2 + 1) * sizeof( int );
370 
371     state->s = (float *) buffer;
372     buffer += (size1 + 1) * sizeof( float );
373 
374     state->d = (float *) buffer;
375     buffer += (size2 + 1) * sizeof( float );
376 
377     /* sum up the supply and demand */
378     for( i = 0; i < size1; i++ )
379     {
380         float weight = signature1[i * (dims + 1)];
381 
382         if( weight > 0 )
383         {
384             s_sum += weight;
385             state->s[ssize] = weight;
386             state->idx1[ssize++] = i;
387 
388         }
389         else if( weight < 0 )
390             CV_Error(CV_StsOutOfRange, "");
391     }
392 
393     for( i = 0; i < size2; i++ )
394     {
395         float weight = signature2[i * (dims + 1)];
396 
397         if( weight > 0 )
398         {
399             d_sum += weight;
400             state->d[dsize] = weight;
401             state->idx2[dsize++] = i;
402         }
403         else if( weight < 0 )
404             CV_Error(CV_StsOutOfRange, "");
405     }
406 
407     if( ssize == 0 || dsize == 0 )
408         CV_Error(CV_StsOutOfRange, "");
409 
410     /* if supply different than the demand, add a zero-cost dummy cluster */
411     diff = s_sum - d_sum;
412     if( fabs( diff ) >= CV_EMD_EPS * s_sum )
413     {
414         equal_sums = 0;
415         if( diff < 0 )
416         {
417             state->s[ssize] = -diff;
418             state->idx1[ssize++] = -1;
419         }
420         else
421         {
422             state->d[dsize] = diff;
423             state->idx2[dsize++] = -1;
424         }
425     }
426 
427     state->ssize = ssize;
428     state->dsize = dsize;
429     state->weight = s_sum > d_sum ? s_sum : d_sum;
430 
431     if( lower_bound && equal_sums )     /* check lower bound */
432     {
433         int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
434         float lb = 0;
435 
436         float* xs = (float *) buffer;
437         float* xd = xs + dims;
438 
439         memset( xs, 0, dims*sizeof(xs[0]));
440         memset( xd, 0, dims*sizeof(xd[0]));
441 
442         for( j = 0; j < sz1; j += dims + 1 )
443         {
444             float weight = signature1[j];
445             for( i = 0; i < dims; i++ )
446                 xs[i] += signature1[j + i + 1] * weight;
447         }
448 
449         for( j = 0; j < sz2; j += dims + 1 )
450         {
451             float weight = signature2[j];
452             for( i = 0; i < dims; i++ )
453                 xd[i] += signature2[j + i + 1] * weight;
454         }
455 
456         lb = dist_func( xs, xd, user_param ) / state->weight;
457         i = *lower_bound <= lb;
458         *lower_bound = lb;
459         if( i )
460             return 1;
461     }
462 
463     /* assign pointers */
464     state->is_used = (char *) buffer;
465     /* init delta matrix */
466     state->delta = (float **) buffer;
467     buffer += ssize * sizeof( float * );
468 
469     for( i = 0; i < ssize; i++ )
470     {
471         state->delta[i] = (float *) buffer;
472         buffer += dsize * sizeof( float );
473     }
474 
475     state->loop = (CvNode2D **) buffer;
476     buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
477 
478     state->_x = state->end_x = (CvNode2D *) buffer;
479     buffer += (ssize + dsize) * sizeof( CvNode2D );
480 
481     /* init cost matrix */
482     state->cost = (float **) buffer;
483     buffer += ssize * sizeof( float * );
484 
485     /* compute the distance matrix */
486     for( i = 0; i < ssize; i++ )
487     {
488         int ci = state->idx1[i];
489 
490         state->cost[i] = (float *) buffer;
491         buffer += dsize * sizeof( float );
492 
493         if( ci >= 0 )
494         {
495             for( j = 0; j < dsize; j++ )
496             {
497                 int cj = state->idx2[j];
498                 if( cj < 0 )
499                     state->cost[i][j] = 0;
500                 else
501                 {
502                     float val;
503                     if( dist_func )
504                     {
505                         val = dist_func( signature1 + ci * (dims + 1) + 1,
506                                          signature2 + cj * (dims + 1) + 1,
507                                          user_param );
508                     }
509                     else
510                     {
511                         assert( cost );
512                         val = cost[cost_step*ci + cj];
513                     }
514                     state->cost[i][j] = val;
515                     if( max_cost < val )
516                         max_cost = val;
517                 }
518             }
519         }
520         else
521         {
522             for( j = 0; j < dsize; j++ )
523                 state->cost[i][j] = 0;
524         }
525     }
526 
527     state->max_cost = max_cost;
528 
529     memset( buffer, 0, buffer_end - buffer );
530 
531     state->rows_x = (CvNode2D **) buffer;
532     buffer += ssize * sizeof( CvNode2D * );
533 
534     state->cols_x = (CvNode2D **) buffer;
535     buffer += dsize * sizeof( CvNode2D * );
536 
537     state->u = (CvNode1D *) buffer;
538     buffer += ssize * sizeof( CvNode1D );
539 
540     state->v = (CvNode1D *) buffer;
541     buffer += dsize * sizeof( CvNode1D );
542 
543     /* init is_x matrix */
544     state->is_x = (char **) buffer;
545     buffer += ssize * sizeof( char * );
546 
547     for( i = 0; i < ssize; i++ )
548     {
549         state->is_x[i] = buffer;
550         buffer += dsize;
551     }
552 
553     assert( buffer <= buffer_end );
554 
555     icvRussel( state );
556 
557     state->enter_x = (state->end_x)++;
558     return 0;
559 }
560 
561 
562 /****************************************************************************************\
563 *                              icvFindBasicVariables                                   *
564 \****************************************************************************************/
icvFindBasicVariables(float ** cost,char ** is_x,CvNode1D * u,CvNode1D * v,int ssize,int dsize)565 static int icvFindBasicVariables( float **cost, char **is_x,
566                        CvNode1D * u, CvNode1D * v, int ssize, int dsize )
567 {
568     int i, j, found;
569     int u_cfound, v_cfound;
570     CvNode1D u0_head, u1_head, *cur_u, *prev_u;
571     CvNode1D v0_head, v1_head, *cur_v, *prev_v;
572 
573     /* initialize the rows list (u) and the columns list (v) */
574     u0_head.next = u;
575     for( i = 0; i < ssize; i++ )
576     {
577         u[i].next = u + i + 1;
578     }
579     u[ssize - 1].next = 0;
580     u1_head.next = 0;
581 
582     v0_head.next = ssize > 1 ? v + 1 : 0;
583     for( i = 1; i < dsize; i++ )
584     {
585         v[i].next = v + i + 1;
586     }
587     v[dsize - 1].next = 0;
588     v1_head.next = 0;
589 
590     /* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
591        so set v[0]=0 */
592     v[0].val = 0;
593     v1_head.next = v;
594     v1_head.next->next = 0;
595 
596     /* loop until all variables are found */
597     u_cfound = v_cfound = 0;
598     while( u_cfound < ssize || v_cfound < dsize )
599     {
600         found = 0;
601         if( v_cfound < dsize )
602         {
603             /* loop over all marked columns */
604             prev_v = &v1_head;
605 
606             for( found |= (cur_v = v1_head.next) != 0; cur_v != 0; cur_v = cur_v->next )
607             {
608                 float cur_v_val = cur_v->val;
609 
610                 j = (int)(cur_v - v);
611                 /* find the variables in column j */
612                 prev_u = &u0_head;
613                 for( cur_u = u0_head.next; cur_u != 0; )
614                 {
615                     i = (int)(cur_u - u);
616                     if( is_x[i][j] )
617                     {
618                         /* compute u[i] */
619                         cur_u->val = cost[i][j] - cur_v_val;
620                         /* ...and add it to the marked list */
621                         prev_u->next = cur_u->next;
622                         cur_u->next = u1_head.next;
623                         u1_head.next = cur_u;
624                         cur_u = prev_u->next;
625                     }
626                     else
627                     {
628                         prev_u = cur_u;
629                         cur_u = cur_u->next;
630                     }
631                 }
632                 prev_v->next = cur_v->next;
633                 v_cfound++;
634             }
635         }
636 
637         if( u_cfound < ssize )
638         {
639             /* loop over all marked rows */
640             prev_u = &u1_head;
641             for( found |= (cur_u = u1_head.next) != 0; cur_u != 0; cur_u = cur_u->next )
642             {
643                 float cur_u_val = cur_u->val;
644                 float *_cost;
645                 char *_is_x;
646 
647                 i = (int)(cur_u - u);
648                 _cost = cost[i];
649                 _is_x = is_x[i];
650                 /* find the variables in rows i */
651                 prev_v = &v0_head;
652                 for( cur_v = v0_head.next; cur_v != 0; )
653                 {
654                     j = (int)(cur_v - v);
655                     if( _is_x[j] )
656                     {
657                         /* compute v[j] */
658                         cur_v->val = _cost[j] - cur_u_val;
659                         /* ...and add it to the marked list */
660                         prev_v->next = cur_v->next;
661                         cur_v->next = v1_head.next;
662                         v1_head.next = cur_v;
663                         cur_v = prev_v->next;
664                     }
665                     else
666                     {
667                         prev_v = cur_v;
668                         cur_v = cur_v->next;
669                     }
670                 }
671                 prev_u->next = cur_u->next;
672                 u_cfound++;
673             }
674         }
675 
676         if( !found )
677             return -1;
678     }
679 
680     return 0;
681 }
682 
683 
684 /****************************************************************************************\
685 *                                   icvIsOptimal                                       *
686 \****************************************************************************************/
687 static float
icvIsOptimal(float ** cost,char ** is_x,CvNode1D * u,CvNode1D * v,int ssize,int dsize,CvNode2D * enter_x)688 icvIsOptimal( float **cost, char **is_x,
689               CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
690 {
691     float delta, min_delta = CV_EMD_INF;
692     int i, j, min_i = 0, min_j = 0;
693 
694     /* find the minimal cij-ui-vj over all i,j */
695     for( i = 0; i < ssize; i++ )
696     {
697         float u_val = u[i].val;
698         float *_cost = cost[i];
699         char *_is_x = is_x[i];
700 
701         for( j = 0; j < dsize; j++ )
702         {
703             if( !_is_x[j] )
704             {
705                 delta = _cost[j] - u_val - v[j].val;
706                 if( min_delta > delta )
707                 {
708                     min_delta = delta;
709                     min_i = i;
710                     min_j = j;
711                 }
712             }
713         }
714     }
715 
716     enter_x->i = min_i;
717     enter_x->j = min_j;
718 
719     return min_delta;
720 }
721 
722 /****************************************************************************************\
723 *                                   icvNewSolution                                     *
724 \****************************************************************************************/
725 static bool
icvNewSolution(CvEMDState * state)726 icvNewSolution( CvEMDState * state )
727 {
728     int i, j;
729     float min_val = CV_EMD_INF;
730     int steps;
731     CvNode2D head, *cur_x, *next_x, *leave_x = 0;
732     CvNode2D *enter_x = state->enter_x;
733     CvNode2D **loop = state->loop;
734 
735     /* enter the new basic variable */
736     i = enter_x->i;
737     j = enter_x->j;
738     state->is_x[i][j] = 1;
739     enter_x->next[0] = state->rows_x[i];
740     enter_x->next[1] = state->cols_x[j];
741     enter_x->val = 0;
742     state->rows_x[i] = enter_x;
743     state->cols_x[j] = enter_x;
744 
745     /* find a chain reaction */
746     steps = icvFindLoop( state );
747 
748     if( steps == 0 )
749         return false;
750 
751     /* find the largest value in the loop */
752     for( i = 1; i < steps; i += 2 )
753     {
754         float temp = loop[i]->val;
755 
756         if( min_val > temp )
757         {
758             leave_x = loop[i];
759             min_val = temp;
760         }
761     }
762 
763     /* update the loop */
764     for( i = 0; i < steps; i += 2 )
765     {
766         float temp0 = loop[i]->val + min_val;
767         float temp1 = loop[i + 1]->val - min_val;
768 
769         loop[i]->val = temp0;
770         loop[i + 1]->val = temp1;
771     }
772 
773     /* remove the leaving basic variable */
774     i = leave_x->i;
775     j = leave_x->j;
776     state->is_x[i][j] = 0;
777 
778     head.next[0] = state->rows_x[i];
779     cur_x = &head;
780     while( (next_x = cur_x->next[0]) != leave_x )
781     {
782         cur_x = next_x;
783         assert( cur_x );
784     }
785     cur_x->next[0] = next_x->next[0];
786     state->rows_x[i] = head.next[0];
787 
788     head.next[1] = state->cols_x[j];
789     cur_x = &head;
790     while( (next_x = cur_x->next[1]) != leave_x )
791     {
792         cur_x = next_x;
793         assert( cur_x );
794     }
795     cur_x->next[1] = next_x->next[1];
796     state->cols_x[j] = head.next[1];
797 
798     /* set enter_x to be the new empty slot */
799     state->enter_x = leave_x;
800 
801     return true;
802 }
803 
804 
805 
806 /****************************************************************************************\
807 *                                    icvFindLoop                                       *
808 \****************************************************************************************/
809 static int
icvFindLoop(CvEMDState * state)810 icvFindLoop( CvEMDState * state )
811 {
812     int i, steps = 1;
813     CvNode2D *new_x;
814     CvNode2D **loop = state->loop;
815     CvNode2D *enter_x = state->enter_x, *_x = state->_x;
816     char *is_used = state->is_used;
817 
818     memset( is_used, 0, state->ssize + state->dsize );
819 
820     new_x = loop[0] = enter_x;
821     is_used[enter_x - _x] = 1;
822     steps = 1;
823 
824     do
825     {
826         if( (steps & 1) == 1 )
827         {
828             /* find an unused x in the row */
829             new_x = state->rows_x[new_x->i];
830             while( new_x != 0 && is_used[new_x - _x] )
831                 new_x = new_x->next[0];
832         }
833         else
834         {
835             /* find an unused x in the column, or the entering x */
836             new_x = state->cols_x[new_x->j];
837             while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
838                 new_x = new_x->next[1];
839             if( new_x == enter_x )
840                 break;
841         }
842 
843         if( new_x != 0 )        /* found the next x */
844         {
845             /* add x to the loop */
846             loop[steps++] = new_x;
847             is_used[new_x - _x] = 1;
848         }
849         else                    /* didn't find the next x */
850         {
851             /* backtrack */
852             do
853             {
854                 i = steps & 1;
855                 new_x = loop[steps - 1];
856                 do
857                 {
858                     new_x = new_x->next[i];
859                 }
860                 while( new_x != 0 && is_used[new_x - _x] );
861 
862                 if( new_x == 0 )
863                 {
864                     is_used[loop[--steps] - _x] = 0;
865                 }
866             }
867             while( new_x == 0 && steps > 0 );
868 
869             is_used[loop[steps - 1] - _x] = 0;
870             loop[steps - 1] = new_x;
871             is_used[new_x - _x] = 1;
872         }
873     }
874     while( steps > 0 );
875 
876     return steps;
877 }
878 
879 
880 
881 /****************************************************************************************\
882 *                                        icvRussel                                     *
883 \****************************************************************************************/
884 static void
icvRussel(CvEMDState * state)885 icvRussel( CvEMDState * state )
886 {
887     int i, j, min_i = -1, min_j = -1;
888     float min_delta, diff;
889     CvNode1D u_head, *cur_u, *prev_u;
890     CvNode1D v_head, *cur_v, *prev_v;
891     CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
892     CvNode1D *u = state->u, *v = state->v;
893     int ssize = state->ssize, dsize = state->dsize;
894     float eps = CV_EMD_EPS * state->max_cost;
895     float **cost = state->cost;
896     float **delta = state->delta;
897 
898     /* initialize the rows list (ur), and the columns list (vr) */
899     u_head.next = u;
900     for( i = 0; i < ssize; i++ )
901     {
902         u[i].next = u + i + 1;
903     }
904     u[ssize - 1].next = 0;
905 
906     v_head.next = v;
907     for( i = 0; i < dsize; i++ )
908     {
909         v[i].val = -CV_EMD_INF;
910         v[i].next = v + i + 1;
911     }
912     v[dsize - 1].next = 0;
913 
914     /* find the maximum row and column values (ur[i] and vr[j]) */
915     for( i = 0; i < ssize; i++ )
916     {
917         float u_val = -CV_EMD_INF;
918         float *cost_row = cost[i];
919 
920         for( j = 0; j < dsize; j++ )
921         {
922             float temp = cost_row[j];
923 
924             if( u_val < temp )
925                 u_val = temp;
926             if( v[j].val < temp )
927                 v[j].val = temp;
928         }
929         u[i].val = u_val;
930     }
931 
932     /* compute the delta matrix */
933     for( i = 0; i < ssize; i++ )
934     {
935         float u_val = u[i].val;
936         float *delta_row = delta[i];
937         float *cost_row = cost[i];
938 
939         for( j = 0; j < dsize; j++ )
940         {
941             delta_row[j] = cost_row[j] - u_val - v[j].val;
942         }
943     }
944 
945     /* find the basic variables */
946     do
947     {
948         /* find the smallest delta[i][j] */
949         min_i = -1;
950         min_delta = CV_EMD_INF;
951         prev_u = &u_head;
952         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
953         {
954             i = (int)(cur_u - u);
955             float *delta_row = delta[i];
956 
957             prev_v = &v_head;
958             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
959             {
960                 j = (int)(cur_v - v);
961                 if( min_delta > delta_row[j] )
962                 {
963                     min_delta = delta_row[j];
964                     min_i = i;
965                     min_j = j;
966                     prev_u_min_i = prev_u;
967                     prev_v_min_j = prev_v;
968                 }
969                 prev_v = cur_v;
970             }
971             prev_u = cur_u;
972         }
973 
974         if( min_i < 0 )
975             break;
976 
977         /* add x[min_i][min_j] to the basis, and adjust supplies and cost */
978         remember = prev_u_min_i->next;
979         icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );
980 
981         /* update the necessary delta[][] */
982         if( remember == prev_u_min_i->next )    /* line min_i was deleted */
983         {
984             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
985             {
986                 j = (int)(cur_v - v);
987                 if( cur_v->val == cost[min_i][j] )      /* column j needs updating */
988                 {
989                     float max_val = -CV_EMD_INF;
990 
991                     /* find the new maximum value in the column */
992                     for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
993                     {
994                         float temp = cost[cur_u - u][j];
995 
996                         if( max_val < temp )
997                             max_val = temp;
998                     }
999 
1000                     /* if needed, adjust the relevant delta[*][j] */
1001                     diff = max_val - cur_v->val;
1002                     cur_v->val = max_val;
1003                     if( fabs( diff ) < eps )
1004                     {
1005                         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1006                             delta[cur_u - u][j] += diff;
1007                     }
1008                 }
1009             }
1010         }
1011         else                    /* column min_j was deleted */
1012         {
1013             for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1014             {
1015                 i = (int)(cur_u - u);
1016                 if( cur_u->val == cost[i][min_j] )      /* row i needs updating */
1017                 {
1018                     float max_val = -CV_EMD_INF;
1019 
1020                     /* find the new maximum value in the row */
1021                     for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1022                     {
1023                         float temp = cost[i][cur_v - v];
1024 
1025                         if( max_val < temp )
1026                             max_val = temp;
1027                     }
1028 
1029                     /* if needed, adjust the relevant delta[i][*] */
1030                     diff = max_val - cur_u->val;
1031                     cur_u->val = max_val;
1032 
1033                     if( fabs( diff ) < eps )
1034                     {
1035                         for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1036                             delta[i][cur_v - v] += diff;
1037                     }
1038                 }
1039             }
1040         }
1041     }
1042     while( u_head.next != 0 || v_head.next != 0 );
1043 }
1044 
1045 
1046 
1047 /****************************************************************************************\
1048 *                                   icvAddBasicVariable                                *
1049 \****************************************************************************************/
1050 static void
icvAddBasicVariable(CvEMDState * state,int min_i,int min_j,CvNode1D * prev_u_min_i,CvNode1D * prev_v_min_j,CvNode1D * u_head)1051 icvAddBasicVariable( CvEMDState * state,
1052                      int min_i, int min_j,
1053                      CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
1054 {
1055     float temp;
1056     CvNode2D *end_x = state->end_x;
1057 
1058     if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
1059     {                           /* supply exhausted */
1060         temp = state->s[min_i];
1061         state->s[min_i] = 0;
1062         state->d[min_j] -= temp;
1063     }
1064     else                        /* demand exhausted */
1065     {
1066         temp = state->d[min_j];
1067         state->d[min_j] = 0;
1068         state->s[min_i] -= temp;
1069     }
1070 
1071     /* x(min_i,min_j) is a basic variable */
1072     state->is_x[min_i][min_j] = 1;
1073 
1074     end_x->val = temp;
1075     end_x->i = min_i;
1076     end_x->j = min_j;
1077     end_x->next[0] = state->rows_x[min_i];
1078     end_x->next[1] = state->cols_x[min_j];
1079     state->rows_x[min_i] = end_x;
1080     state->cols_x[min_j] = end_x;
1081     state->end_x = end_x + 1;
1082 
1083     /* delete supply row only if the empty, and if not last row */
1084     if( state->s[min_i] == 0 && u_head->next->next != 0 )
1085         prev_u_min_i->next = prev_u_min_i->next->next;  /* remove row from list */
1086     else
1087         prev_v_min_j->next = prev_v_min_j->next->next;  /* remove column from list */
1088 }
1089 
1090 
1091 /****************************************************************************************\
1092 *                                  standard  metrics                                     *
1093 \****************************************************************************************/
static float
icvDistL1( const float *x, const float *y, void *user_param )
/* Manhattan (L1) distance between two dims-long vectors.
   The dimension count is smuggled through user_param as an integer. */
{
    const int dims = (int)(size_t)user_param;
    double total = 0;

    for( int k = 0; k < dims; ++k )
        total += fabs( (double)(x[k] - y[k]) );

    return (float)total;
}
1108 
1109 static float
icvDistL2(const float * x,const float * y,void * user_param)1110 icvDistL2( const float *x, const float *y, void *user_param )
1111 {
1112     int i, dims = (int)(size_t)user_param;
1113     double s = 0;
1114 
1115     for( i = 0; i < dims; i++ )
1116     {
1117         double t = x[i] - y[i];
1118 
1119         s += t * t;
1120     }
1121     return cvSqrt( (float)s );
1122 }
1123 
static float
icvDistC( const float *x, const float *y, void *user_param )
/* Chebyshev (L-infinity) distance: the largest per-coordinate absolute
   difference between two dims-long vectors. */
{
    const int dims = (int)(size_t)user_param;
    double largest = 0;

    for( int k = 0; k < dims; ++k )
    {
        const double d = fabs( x[k] - y[k] );
        if( d > largest )
            largest = d;
    }

    return (float)largest;
}
1139 
1140 
EMD(InputArray _signature1,InputArray _signature2,int distType,InputArray _cost,float * lowerBound,OutputArray _flow)1141 float cv::EMD( InputArray _signature1, InputArray _signature2,
1142                int distType, InputArray _cost,
1143                float* lowerBound, OutputArray _flow )
1144 {
1145     Mat signature1 = _signature1.getMat(), signature2 = _signature2.getMat();
1146     Mat cost = _cost.getMat(), flow;
1147 
1148     CvMat _csignature1 = signature1;
1149     CvMat _csignature2 = signature2;
1150     CvMat _ccost = cost, _cflow;
1151     if( _flow.needed() )
1152     {
1153         _flow.create(signature1.rows, signature2.rows, CV_32F);
1154         flow = _flow.getMat();
1155         flow = Scalar::all(0);
1156         _cflow = flow;
1157     }
1158 
1159     return cvCalcEMD2( &_csignature1, &_csignature2, distType, 0, cost.empty() ? 0 : &_ccost,
1160                        _flow.needed() ? &_cflow : 0, lowerBound, 0 );
1161 }
1162 
1163 /* End of file. */
1164