1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41 
42 /*
43     Partially based on Yossi Rubner code:
44     =========================================================================
45     emd.c
46 
47     Last update: 3/14/98
48 
49     An implementation of the Earth Movers Distance.
    Based on the solution for the Transportation problem as described in
51     "Introduction to Mathematical Programming" by F. S. Hillier and
52     G. J. Lieberman, McGraw-Hill, 1990.
53 
54     Copyright (C) 1998 Yossi Rubner
55     Computer Science Department, Stanford University
56     E-Mail: rubner@cs.stanford.edu   URL: http://vision.stanford.edu/~rubner
57     ==========================================================================
58 */
59 #include "_cv.h"
60 
/* upper bound on the simplex iteration counter used by cvCalcEMD2 */
#define MAX_ITERATIONS 500
/* "infinity": starting value for minimum searches (icvIsOptimal returns it
   unchanged when no non-basic cell improves on it) */
#define CV_EMD_INF   ((float)1e20)
/* relative tolerance; multiplied by the total supply when comparing supply and
   demand, and by the largest ground distance for optimality tests */
#define CV_EMD_EPS   ((float)1e-5)
64 
/*
   CvNode1D: one element of a singly linked list that represents a 1D sparse
   array (the row potentials u and the column potentials v are stored this way).
*/
typedef struct CvNode1D CvNode1D;

struct CvNode1D
{
    float val;          /* value stored in this element */
    CvNode1D *next;     /* following element of the list, or NULL */
};
72 
/*
   CvNode2D: one cell (i, j) of a 2D sparse matrix. Every cell is linked into two
   lists at once: the list of its row and the list of its column.
*/
typedef struct CvNode2D CvNode2D;

struct CvNode2D
{
    float val;              /* value of the cell (the flow for basic variables) */
    CvNode2D *next[2];      /* next[0]: next cell in the row, next[1]: next in the column */
    int i;                  /* row index */
    int j;                  /* column index */
};
81 
82 
83 typedef struct CvEMDState
84 {
85     int ssize, dsize;
86 
87     float **cost;
88     CvNode2D *_x;
89     CvNode2D *end_x;
90     CvNode2D *enter_x;
91     char **is_x;
92 
93     CvNode2D **rows_x;
94     CvNode2D **cols_x;
95 
96     CvNode1D *u;
97     CvNode1D *v;
98 
99     int* idx1;
100     int* idx2;
101 
102     /* find_loop buffers */
103     CvNode2D **loop;
104     char *is_used;
105 
106     /* russel buffers */
107     float *s;
108     float *d;
109     float **delta;
110 
111     float weight, max_cost;
112     char *buffer;
113 }
114 CvEMDState;
115 
116 /* static function declaration */
117 static CvStatus icvInitEMD( const float *signature1, int size1,
118                             const float *signature2, int size2,
119                             int dims, CvDistanceFunction dist_func, void *user_param,
120                             const float* cost, int cost_step,
121                             CvEMDState * state, float *lower_bound,
122                             char *local_buffer, int local_buffer_size );
123 
124 static CvStatus icvFindBasicVariables( float **cost, char **is_x,
125                                        CvNode1D * u, CvNode1D * v, int ssize, int dsize );
126 
127 static float icvIsOptimal( float **cost, char **is_x,
128                            CvNode1D * u, CvNode1D * v,
129                            int ssize, int dsize, CvNode2D * enter_x );
130 
131 static void icvRussel( CvEMDState * state );
132 
133 
134 static CvStatus icvNewSolution( CvEMDState * state );
135 static int icvFindLoop( CvEMDState * state );
136 
137 static void icvAddBasicVariable( CvEMDState * state,
138                                  int min_i, int min_j,
139                                  CvNode1D * prev_u_min_i,
140                                  CvNode1D * prev_v_min_j,
141                                  CvNode1D * u_head );
142 
143 static float icvDistL2( const float *x, const float *y, void *user_param );
144 static float icvDistL1( const float *x, const float *y, void *user_param );
145 static float icvDistC( const float *x, const float *y, void *user_param );
146 
147 /* The main function */
148 CV_IMPL float
cvCalcEMD2(const CvArr * signature_arr1,const CvArr * signature_arr2,int dist_type,CvDistanceFunction dist_func,const CvArr * cost_matrix,CvArr * flow_matrix,float * lower_bound,void * user_param)149 cvCalcEMD2( const CvArr* signature_arr1,
150             const CvArr* signature_arr2,
151             int dist_type,
152             CvDistanceFunction dist_func,
153             const CvArr* cost_matrix,
154             CvArr* flow_matrix,
155             float *lower_bound,
156             void *user_param )
157 {
158     char local_buffer[16384];
159     char *local_buffer_ptr = (char *)cvAlignPtr(local_buffer,16);
160     CvEMDState state;
161     float emd = 0;
162 
163     CV_FUNCNAME( "cvCalcEMD2" );
164 
165     memset( &state, 0, sizeof(state));
166 
167     __BEGIN__;
168 
169     double total_cost = 0;
170     CvStatus result = CV_NO_ERR;
171     float eps, min_delta;
172     CvNode2D *xp = 0;
173     CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
174     CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
175     CvMat cost_stub, *cost = &cost_stub;
176     CvMat flow_stub, *flow = (CvMat*)flow_matrix;
177     int dims, size1, size2;
178 
179     CV_CALL( signature1 = cvGetMat( signature1, &sign_stub1 ));
180     CV_CALL( signature2 = cvGetMat( signature2, &sign_stub2 ));
181 
182     if( signature1->cols != signature2->cols )
183         CV_ERROR( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
184 
185     dims = signature1->cols - 1;
186     size1 = signature1->rows;
187     size2 = signature2->rows;
188 
189     if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
190         CV_ERROR( CV_StsUnmatchedFormats, "The array must have equal types" );
191 
192     if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
193         CV_ERROR( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
194 
195     if( flow )
196     {
197         CV_CALL( flow = cvGetMat( flow, &flow_stub ));
198 
199         if( flow->rows != size1 || flow->cols != size2 )
200             CV_ERROR( CV_StsUnmatchedSizes,
201             "The flow matrix size does not match to the signatures' sizes" );
202 
203         if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
204             CV_ERROR( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
205     }
206 
207     cost->data.fl = 0;
208     cost->step = 0;
209 
210     if( dist_type < 0 )
211     {
212         if( cost_matrix )
213         {
214             if( dist_func )
215                 CV_ERROR( CV_StsBadArg,
216                 "Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
217 
218             if( lower_bound )
219                 CV_ERROR( CV_StsBadArg,
220                 "The lower boundary can not be calculated if the cost matrix is used" );
221 
222             CV_CALL( cost = cvGetMat( cost_matrix, &cost_stub ));
223             if( cost->rows != size1 || cost->cols != size2 )
224                 CV_ERROR( CV_StsUnmatchedSizes,
225                 "The cost matrix size does not match to the signatures' sizes" );
226 
227             if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
228                 CV_ERROR( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
229         }
230         else if( !dist_func )
231             CV_ERROR( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
232     }
233     else
234     {
235         if( dims == 0 )
236             CV_ERROR( CV_StsBadSize,
237             "Number of dimensions can be 0 only if a user-defined metric is used" );
238         user_param = (void *) (size_t)dims;
239         switch (dist_type)
240         {
241         case CV_DIST_L1:
242             dist_func = icvDistL1;
243             break;
244         case CV_DIST_L2:
245             dist_func = icvDistL2;
246             break;
247         case CV_DIST_C:
248             dist_func = icvDistC;
249             break;
250         default:
251             CV_ERROR( CV_StsBadFlag, "Bad or unsupported metric type" );
252         }
253     }
254 
255     IPPI_CALL( result = icvInitEMD( signature1->data.fl, size1,
256                                     signature2->data.fl, size2,
257                                     dims, dist_func, user_param,
258                                     cost->data.fl, cost->step,
259                                     &state, lower_bound, local_buffer_ptr,
260                                     sizeof( local_buffer ) - 16 ));
261 
262     if( result > 0 && lower_bound )
263     {
264         emd = *lower_bound;
265         EXIT;
266     }
267 
268     eps = CV_EMD_EPS * state.max_cost;
269 
270     /* if ssize = 1 or dsize = 1 then we are done, else ... */
271     if( state.ssize > 1 && state.dsize > 1 )
272     {
273         int itr;
274 
275         for( itr = 1; itr < MAX_ITERATIONS; itr++ )
276         {
277             /* find basic variables */
278             result = icvFindBasicVariables( state.cost, state.is_x,
279                                             state.u, state.v, state.ssize, state.dsize );
280             if( result < 0 )
281                 break;
282 
283             /* check for optimality */
284             min_delta = icvIsOptimal( state.cost, state.is_x,
285                                       state.u, state.v,
286                                       state.ssize, state.dsize, state.enter_x );
287 
288             if( min_delta == CV_EMD_INF )
289             {
290                 CV_ERROR( CV_StsNoConv, "" );
291             }
292 
293             /* if no negative deltamin, we found the optimal solution */
294             if( min_delta >= -eps )
295                 break;
296 
297             /* improve solution */
298             IPPI_CALL( icvNewSolution( &state ));
299         }
300     }
301 
302     /* compute the total flow */
303     for( xp = state._x; xp < state.end_x; xp++ )
304     {
305         float val = xp->val;
306         int i = xp->i;
307         int j = xp->j;
308         int ci = state.idx1[i];
309         int cj = state.idx2[j];
310 
311         if( xp != state.enter_x && ci >= 0 && cj >= 0 )
312         {
313             total_cost += (double)val * state.cost[i][j];
314             if( flow )
315                 ((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
316         }
317     }
318 
319     emd = (float) (total_cost / state.weight);
320 
321     __END__;
322 
323     if( state.buffer && state.buffer != local_buffer_ptr )
324         cvFree( &state.buffer );
325 
326     return emd;
327 }
328 
329 
330 /************************************************************************************\
331 *          initialize structure, allocate buffers and generate initial golution      *
332 \************************************************************************************/
333 static CvStatus
icvInitEMD(const float * signature1,int size1,const float * signature2,int size2,int dims,CvDistanceFunction dist_func,void * user_param,const float * cost,int cost_step,CvEMDState * state,float * lower_bound,char * local_buffer,int local_buffer_size)334 icvInitEMD( const float* signature1, int size1,
335             const float* signature2, int size2,
336             int dims, CvDistanceFunction dist_func, void* user_param,
337             const float* cost, int cost_step,
338             CvEMDState* state, float* lower_bound,
339             char* local_buffer, int local_buffer_size )
340 {
341     float s_sum = 0, d_sum = 0, diff;
342     int i, j;
343     int ssize = 0, dsize = 0;
344     int equal_sums = 1;
345     int buffer_size;
346     float max_cost = 0;
347     char *buffer, *buffer_end;
348 
349     memset( state, 0, sizeof( *state ));
350     assert( cost_step % sizeof(float) == 0 );
351     cost_step /= sizeof(float);
352 
353     /* calculate buffer size */
354     buffer_size = (size1+1) * (size2+1) * (sizeof( float ) +    /* cost */
355                                    sizeof( char ) +     /* is_x */
356                                    sizeof( float )) +   /* delta matrix */
357         (size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
358                            sizeof( CvNode2D * ) +  /* cols_x & rows_x */
359                            sizeof( CvNode1D ) + /* u & v */
360                            sizeof( float ) + /* s & d */
361                            sizeof( int ) + sizeof(CvNode2D*)) +  /* idx1 & idx2 */
362         (size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
363                  sizeof( float * )) + 256;      /*  cost, is_x and delta */
364 
365     if( buffer_size < (int) (dims * 2 * sizeof( float )))
366     {
367         buffer_size = dims * 2 * sizeof( float );
368     }
369 
370     /* allocate buffers */
371     if( local_buffer != 0 && local_buffer_size >= buffer_size )
372     {
373         buffer = local_buffer;
374     }
375     else
376     {
377         buffer = (char*)cvAlloc( buffer_size );
378         if( !buffer )
379             return CV_OUTOFMEM_ERR;
380     }
381 
382     state->buffer = buffer;
383     buffer_end = buffer + buffer_size;
384 
385     state->idx1 = (int*) buffer;
386     buffer += (size1 + 1) * sizeof( int );
387 
388     state->idx2 = (int*) buffer;
389     buffer += (size2 + 1) * sizeof( int );
390 
391     state->s = (float *) buffer;
392     buffer += (size1 + 1) * sizeof( float );
393 
394     state->d = (float *) buffer;
395     buffer += (size2 + 1) * sizeof( float );
396 
397     /* sum up the supply and demand */
398     for( i = 0; i < size1; i++ )
399     {
400         float weight = signature1[i * (dims + 1)];
401 
402         if( weight > 0 )
403         {
404             s_sum += weight;
405             state->s[ssize] = weight;
406             state->idx1[ssize++] = i;
407 
408         }
409         else if( weight < 0 )
410             return CV_BADRANGE_ERR;
411     }
412 
413     for( i = 0; i < size2; i++ )
414     {
415         float weight = signature2[i * (dims + 1)];
416 
417         if( weight > 0 )
418         {
419             d_sum += weight;
420             state->d[dsize] = weight;
421             state->idx2[dsize++] = i;
422         }
423         else if( weight < 0 )
424             return CV_BADRANGE_ERR;
425     }
426 
427     if( ssize == 0 || dsize == 0 )
428         return CV_BADRANGE_ERR;
429 
430     /* if supply different than the demand, add a zero-cost dummy cluster */
431     diff = s_sum - d_sum;
432     if( fabs( diff ) >= CV_EMD_EPS * s_sum )
433     {
434         equal_sums = 0;
435         if( diff < 0 )
436         {
437             state->s[ssize] = -diff;
438             state->idx1[ssize++] = -1;
439         }
440         else
441         {
442             state->d[dsize] = diff;
443             state->idx2[dsize++] = -1;
444         }
445     }
446 
447     state->ssize = ssize;
448     state->dsize = dsize;
449     state->weight = s_sum > d_sum ? s_sum : d_sum;
450 
451     if( lower_bound && equal_sums )     /* check lower bound */
452     {
453         int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
454         float lb = 0;
455 
456         float* xs = (float *) buffer;
457         float* xd = xs + dims;
458 
459         memset( xs, 0, dims*sizeof(xs[0]));
460         memset( xd, 0, dims*sizeof(xd[0]));
461 
462         for( j = 0; j < sz1; j += dims + 1 )
463         {
464             float weight = signature1[j];
465             for( i = 0; i < dims; i++ )
466                 xs[i] += signature1[j + i + 1] * weight;
467         }
468 
469         for( j = 0; j < sz2; j += dims + 1 )
470         {
471             float weight = signature2[j];
472             for( i = 0; i < dims; i++ )
473                 xd[i] += signature2[j + i + 1] * weight;
474         }
475 
476         lb = dist_func( xs, xd, user_param ) / state->weight;
477         i = *lower_bound <= lb;
478         *lower_bound = lb;
479         if( i )
480             return ( CvStatus ) 1;
481     }
482 
483     /* assign pointers */
484     state->is_used = (char *) buffer;
485     /* init delta matrix */
486     state->delta = (float **) buffer;
487     buffer += ssize * sizeof( float * );
488 
489     for( i = 0; i < ssize; i++ )
490     {
491         state->delta[i] = (float *) buffer;
492         buffer += dsize * sizeof( float );
493     }
494 
495     state->loop = (CvNode2D **) buffer;
496     buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
497 
498     state->_x = state->end_x = (CvNode2D *) buffer;
499     buffer += (ssize + dsize) * sizeof( CvNode2D );
500 
501     /* init cost matrix */
502     state->cost = (float **) buffer;
503     buffer += ssize * sizeof( float * );
504 
505     /* compute the distance matrix */
506     for( i = 0; i < ssize; i++ )
507     {
508         int ci = state->idx1[i];
509 
510         state->cost[i] = (float *) buffer;
511         buffer += dsize * sizeof( float );
512 
513         if( ci >= 0 )
514         {
515             for( j = 0; j < dsize; j++ )
516             {
517                 int cj = state->idx2[j];
518                 if( cj < 0 )
519                     state->cost[i][j] = 0;
520                 else
521                 {
522                     float val;
523                     if( dist_func )
524                     {
525                         val = dist_func( signature1 + ci * (dims + 1) + 1,
526                                          signature2 + cj * (dims + 1) + 1,
527                                          user_param );
528                     }
529                     else
530                     {
531                         assert( cost );
532                         val = cost[cost_step*ci + cj];
533                     }
534                     state->cost[i][j] = val;
535                     if( max_cost < val )
536                         max_cost = val;
537                 }
538             }
539         }
540         else
541         {
542             for( j = 0; j < dsize; j++ )
543                 state->cost[i][j] = 0;
544         }
545     }
546 
547     state->max_cost = max_cost;
548 
549     memset( buffer, 0, buffer_end - buffer );
550 
551     state->rows_x = (CvNode2D **) buffer;
552     buffer += ssize * sizeof( CvNode2D * );
553 
554     state->cols_x = (CvNode2D **) buffer;
555     buffer += dsize * sizeof( CvNode2D * );
556 
557     state->u = (CvNode1D *) buffer;
558     buffer += ssize * sizeof( CvNode1D );
559 
560     state->v = (CvNode1D *) buffer;
561     buffer += dsize * sizeof( CvNode1D );
562 
563     /* init is_x matrix */
564     state->is_x = (char **) buffer;
565     buffer += ssize * sizeof( char * );
566 
567     for( i = 0; i < ssize; i++ )
568     {
569         state->is_x[i] = buffer;
570         buffer += dsize;
571     }
572 
573     assert( buffer <= buffer_end );
574 
575     icvRussel( state );
576 
577     state->enter_x = (state->end_x)++;
578     return CV_NO_ERR;
579 }
580 
581 
582 /****************************************************************************************\
583 *                              icvFindBasicVariables                                   *
584 \****************************************************************************************/
585 static CvStatus
icvFindBasicVariables(float ** cost,char ** is_x,CvNode1D * u,CvNode1D * v,int ssize,int dsize)586 icvFindBasicVariables( float **cost, char **is_x,
587                        CvNode1D * u, CvNode1D * v, int ssize, int dsize )
588 {
589     int i, j, found;
590     int u_cfound, v_cfound;
591     CvNode1D u0_head, u1_head, *cur_u, *prev_u;
592     CvNode1D v0_head, v1_head, *cur_v, *prev_v;
593 
594     /* initialize the rows list (u) and the columns list (v) */
595     u0_head.next = u;
596     for( i = 0; i < ssize; i++ )
597     {
598         u[i].next = u + i + 1;
599     }
600     u[ssize - 1].next = 0;
601     u1_head.next = 0;
602 
603     v0_head.next = ssize > 1 ? v + 1 : 0;
604     for( i = 1; i < dsize; i++ )
605     {
606         v[i].next = v + i + 1;
607     }
608     v[dsize - 1].next = 0;
609     v1_head.next = 0;
610 
611     /* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
612        so set v[0]=0 */
613     v[0].val = 0;
614     v1_head.next = v;
615     v1_head.next->next = 0;
616 
617     /* loop until all variables are found */
618     u_cfound = v_cfound = 0;
619     while( u_cfound < ssize || v_cfound < dsize )
620     {
621         found = 0;
622         if( v_cfound < dsize )
623         {
624             /* loop over all marked columns */
625             prev_v = &v1_head;
626 
627             for( found |= (cur_v = v1_head.next) != 0; cur_v != 0; cur_v = cur_v->next )
628             {
629                 float cur_v_val = cur_v->val;
630 
631                 j = (int)(cur_v - v);
632                 /* find the variables in column j */
633                 prev_u = &u0_head;
634                 for( cur_u = u0_head.next; cur_u != 0; )
635                 {
636                     i = (int)(cur_u - u);
637                     if( is_x[i][j] )
638                     {
639                         /* compute u[i] */
640                         cur_u->val = cost[i][j] - cur_v_val;
641                         /* ...and add it to the marked list */
642                         prev_u->next = cur_u->next;
643                         cur_u->next = u1_head.next;
644                         u1_head.next = cur_u;
645                         cur_u = prev_u->next;
646                     }
647                     else
648                     {
649                         prev_u = cur_u;
650                         cur_u = cur_u->next;
651                     }
652                 }
653                 prev_v->next = cur_v->next;
654                 v_cfound++;
655             }
656         }
657 
658         if( u_cfound < ssize )
659         {
660             /* loop over all marked rows */
661             prev_u = &u1_head;
662             for( found |= (cur_u = u1_head.next) != 0; cur_u != 0; cur_u = cur_u->next )
663             {
664                 float cur_u_val = cur_u->val;
665                 float *_cost;
666                 char *_is_x;
667 
668                 i = (int)(cur_u - u);
669                 _cost = cost[i];
670                 _is_x = is_x[i];
671                 /* find the variables in rows i */
672                 prev_v = &v0_head;
673                 for( cur_v = v0_head.next; cur_v != 0; )
674                 {
675                     j = (int)(cur_v - v);
676                     if( _is_x[j] )
677                     {
678                         /* compute v[j] */
679                         cur_v->val = _cost[j] - cur_u_val;
680                         /* ...and add it to the marked list */
681                         prev_v->next = cur_v->next;
682                         cur_v->next = v1_head.next;
683                         v1_head.next = cur_v;
684                         cur_v = prev_v->next;
685                     }
686                     else
687                     {
688                         prev_v = cur_v;
689                         cur_v = cur_v->next;
690                     }
691                 }
692                 prev_u->next = cur_u->next;
693                 u_cfound++;
694             }
695         }
696 
697         if( !found )
698         {
699             return CV_NOTDEFINED_ERR;
700         }
701     }
702 
703     return CV_NO_ERR;
704 }
705 
706 
707 /****************************************************************************************\
708 *                                   icvIsOptimal                                       *
709 \****************************************************************************************/
710 static float
icvIsOptimal(float ** cost,char ** is_x,CvNode1D * u,CvNode1D * v,int ssize,int dsize,CvNode2D * enter_x)711 icvIsOptimal( float **cost, char **is_x,
712               CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
713 {
714     float delta, min_delta = CV_EMD_INF;
715     int i, j, min_i = 0, min_j = 0;
716 
717     /* find the minimal cij-ui-vj over all i,j */
718     for( i = 0; i < ssize; i++ )
719     {
720         float u_val = u[i].val;
721         float *_cost = cost[i];
722         char *_is_x = is_x[i];
723 
724         for( j = 0; j < dsize; j++ )
725         {
726             if( !_is_x[j] )
727             {
728                 delta = _cost[j] - u_val - v[j].val;
729                 if( min_delta > delta )
730                 {
731                     min_delta = delta;
732                     min_i = i;
733                     min_j = j;
734                 }
735             }
736         }
737     }
738 
739     enter_x->i = min_i;
740     enter_x->j = min_j;
741 
742     return min_delta;
743 }
744 
745 /****************************************************************************************\
746 *                                   icvNewSolution                                     *
747 \****************************************************************************************/
748 static CvStatus
icvNewSolution(CvEMDState * state)749 icvNewSolution( CvEMDState * state )
750 {
751     int i, j;
752     float min_val = CV_EMD_INF;
753     int steps;
754     CvNode2D head, *cur_x, *next_x, *leave_x = 0;
755     CvNode2D *enter_x = state->enter_x;
756     CvNode2D **loop = state->loop;
757 
758     /* enter the new basic variable */
759     i = enter_x->i;
760     j = enter_x->j;
761     state->is_x[i][j] = 1;
762     enter_x->next[0] = state->rows_x[i];
763     enter_x->next[1] = state->cols_x[j];
764     enter_x->val = 0;
765     state->rows_x[i] = enter_x;
766     state->cols_x[j] = enter_x;
767 
768     /* find a chain reaction */
769     steps = icvFindLoop( state );
770 
771     if( steps == 0 )
772         return CV_NOTDEFINED_ERR;
773 
774     /* find the largest value in the loop */
775     for( i = 1; i < steps; i += 2 )
776     {
777         float temp = loop[i]->val;
778 
779         if( min_val > temp )
780         {
781             leave_x = loop[i];
782             min_val = temp;
783         }
784     }
785 
786     /* update the loop */
787     for( i = 0; i < steps; i += 2 )
788     {
789         float temp0 = loop[i]->val + min_val;
790         float temp1 = loop[i + 1]->val - min_val;
791 
792         loop[i]->val = temp0;
793         loop[i + 1]->val = temp1;
794     }
795 
796     /* remove the leaving basic variable */
797     i = leave_x->i;
798     j = leave_x->j;
799     state->is_x[i][j] = 0;
800 
801     head.next[0] = state->rows_x[i];
802     cur_x = &head;
803     while( (next_x = cur_x->next[0]) != leave_x )
804     {
805         cur_x = next_x;
806         assert( cur_x );
807     }
808     cur_x->next[0] = next_x->next[0];
809     state->rows_x[i] = head.next[0];
810 
811     head.next[1] = state->cols_x[j];
812     cur_x = &head;
813     while( (next_x = cur_x->next[1]) != leave_x )
814     {
815         cur_x = next_x;
816         assert( cur_x );
817     }
818     cur_x->next[1] = next_x->next[1];
819     state->cols_x[j] = head.next[1];
820 
821     /* set enter_x to be the new empty slot */
822     state->enter_x = leave_x;
823 
824     return CV_NO_ERR;
825 }
826 
827 
828 
829 /****************************************************************************************\
830 *                                    icvFindLoop                                       *
831 \****************************************************************************************/
832 static int
icvFindLoop(CvEMDState * state)833 icvFindLoop( CvEMDState * state )
834 {
835     int i, steps = 1;
836     CvNode2D *new_x;
837     CvNode2D **loop = state->loop;
838     CvNode2D *enter_x = state->enter_x, *_x = state->_x;
839     char *is_used = state->is_used;
840 
841     memset( is_used, 0, state->ssize + state->dsize );
842 
843     new_x = loop[0] = enter_x;
844     is_used[enter_x - _x] = 1;
845     steps = 1;
846 
847     do
848     {
849         if( (steps & 1) == 1 )
850         {
851             /* find an unused x in the row */
852             new_x = state->rows_x[new_x->i];
853             while( new_x != 0 && is_used[new_x - _x] )
854                 new_x = new_x->next[0];
855         }
856         else
857         {
858             /* find an unused x in the column, or the entering x */
859             new_x = state->cols_x[new_x->j];
860             while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
861                 new_x = new_x->next[1];
862             if( new_x == enter_x )
863                 break;
864         }
865 
866         if( new_x != 0 )        /* found the next x */
867         {
868             /* add x to the loop */
869             loop[steps++] = new_x;
870             is_used[new_x - _x] = 1;
871         }
872         else                    /* didn't find the next x */
873         {
874             /* backtrack */
875             do
876             {
877                 i = steps & 1;
878                 new_x = loop[steps - 1];
879                 do
880                 {
881                     new_x = new_x->next[i];
882                 }
883                 while( new_x != 0 && is_used[new_x - _x] );
884 
885                 if( new_x == 0 )
886                 {
887                     is_used[loop[--steps] - _x] = 0;
888                 }
889             }
890             while( new_x == 0 && steps > 0 );
891 
892             is_used[loop[steps - 1] - _x] = 0;
893             loop[steps - 1] = new_x;
894             is_used[new_x - _x] = 1;
895         }
896     }
897     while( steps > 0 );
898 
899     return steps;
900 }
901 
902 
903 
904 /****************************************************************************************\
905 *                                        icvRussel                                     *
906 \****************************************************************************************/
907 static void
icvRussel(CvEMDState * state)908 icvRussel( CvEMDState * state )
909 {
910     int i, j, min_i = -1, min_j = -1;
911     float min_delta, diff;
912     CvNode1D u_head, *cur_u, *prev_u;
913     CvNode1D v_head, *cur_v, *prev_v;
914     CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
915     CvNode1D *u = state->u, *v = state->v;
916     int ssize = state->ssize, dsize = state->dsize;
917     float eps = CV_EMD_EPS * state->max_cost;
918     float **cost = state->cost;
919     float **delta = state->delta;
920 
921     /* initialize the rows list (ur), and the columns list (vr) */
922     u_head.next = u;
923     for( i = 0; i < ssize; i++ )
924     {
925         u[i].next = u + i + 1;
926     }
927     u[ssize - 1].next = 0;
928 
929     v_head.next = v;
930     for( i = 0; i < dsize; i++ )
931     {
932         v[i].val = -CV_EMD_INF;
933         v[i].next = v + i + 1;
934     }
935     v[dsize - 1].next = 0;
936 
937     /* find the maximum row and column values (ur[i] and vr[j]) */
938     for( i = 0; i < ssize; i++ )
939     {
940         float u_val = -CV_EMD_INF;
941         float *cost_row = cost[i];
942 
943         for( j = 0; j < dsize; j++ )
944         {
945             float temp = cost_row[j];
946 
947             if( u_val < temp )
948                 u_val = temp;
949             if( v[j].val < temp )
950                 v[j].val = temp;
951         }
952         u[i].val = u_val;
953     }
954 
955     /* compute the delta matrix */
956     for( i = 0; i < ssize; i++ )
957     {
958         float u_val = u[i].val;
959         float *delta_row = delta[i];
960         float *cost_row = cost[i];
961 
962         for( j = 0; j < dsize; j++ )
963         {
964             delta_row[j] = cost_row[j] - u_val - v[j].val;
965         }
966     }
967 
968     /* find the basic variables */
969     do
970     {
971         /* find the smallest delta[i][j] */
972         min_i = -1;
973         min_delta = CV_EMD_INF;
974         prev_u = &u_head;
975         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
976         {
977             i = (int)(cur_u - u);
978             float *delta_row = delta[i];
979 
980             prev_v = &v_head;
981             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
982             {
983                 j = (int)(cur_v - v);
984                 if( min_delta > delta_row[j] )
985                 {
986                     min_delta = delta_row[j];
987                     min_i = i;
988                     min_j = j;
989                     prev_u_min_i = prev_u;
990                     prev_v_min_j = prev_v;
991                 }
992                 prev_v = cur_v;
993             }
994             prev_u = cur_u;
995         }
996 
997         if( min_i < 0 )
998             break;
999 
1000         /* add x[min_i][min_j] to the basis, and adjust supplies and cost */
1001         remember = prev_u_min_i->next;
1002         icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );
1003 
1004         /* update the necessary delta[][] */
1005         if( remember == prev_u_min_i->next )    /* line min_i was deleted */
1006         {
1007             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1008             {
1009                 j = (int)(cur_v - v);
1010                 if( cur_v->val == cost[min_i][j] )      /* column j needs updating */
1011                 {
1012                     float max_val = -CV_EMD_INF;
1013 
1014                     /* find the new maximum value in the column */
1015                     for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1016                     {
1017                         float temp = cost[cur_u - u][j];
1018 
1019                         if( max_val < temp )
1020                             max_val = temp;
1021                     }
1022 
1023                     /* if needed, adjust the relevant delta[*][j] */
1024                     diff = max_val - cur_v->val;
1025                     cur_v->val = max_val;
1026                     if( fabs( diff ) < eps )
1027                     {
1028                         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1029                             delta[cur_u - u][j] += diff;
1030                     }
1031                 }
1032             }
1033         }
1034         else                    /* column min_j was deleted */
1035         {
1036             for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1037             {
1038                 i = (int)(cur_u - u);
1039                 if( cur_u->val == cost[i][min_j] )      /* row i needs updating */
1040                 {
1041                     float max_val = -CV_EMD_INF;
1042 
1043                     /* find the new maximum value in the row */
1044                     for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1045                     {
1046                         float temp = cost[i][cur_v - v];
1047 
1048                         if( max_val < temp )
1049                             max_val = temp;
1050                     }
1051 
1052                     /* if needed, adjust the relevant delta[i][*] */
1053                     diff = max_val - cur_u->val;
1054                     cur_u->val = max_val;
1055 
1056                     if( fabs( diff ) < eps )
1057                     {
1058                         for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1059                             delta[i][cur_v - v] += diff;
1060                     }
1061                 }
1062             }
1063         }
1064     }
1065     while( u_head.next != 0 || v_head.next != 0 );
1066 }
1067 
1068 
1069 
1070 /****************************************************************************************\
1071 *                                   icvAddBasicVariable                                *
1072 \****************************************************************************************/
1073 static void
icvAddBasicVariable(CvEMDState * state,int min_i,int min_j,CvNode1D * prev_u_min_i,CvNode1D * prev_v_min_j,CvNode1D * u_head)1074 icvAddBasicVariable( CvEMDState * state,
1075                      int min_i, int min_j,
1076                      CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
1077 {
1078     float temp;
1079     CvNode2D *end_x = state->end_x;
1080 
1081     if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
1082     {                           /* supply exhausted */
1083         temp = state->s[min_i];
1084         state->s[min_i] = 0;
1085         state->d[min_j] -= temp;
1086     }
1087     else                        /* demand exhausted */
1088     {
1089         temp = state->d[min_j];
1090         state->d[min_j] = 0;
1091         state->s[min_i] -= temp;
1092     }
1093 
1094     /* x(min_i,min_j) is a basic variable */
1095     state->is_x[min_i][min_j] = 1;
1096 
1097     end_x->val = temp;
1098     end_x->i = min_i;
1099     end_x->j = min_j;
1100     end_x->next[0] = state->rows_x[min_i];
1101     end_x->next[1] = state->cols_x[min_j];
1102     state->rows_x[min_i] = end_x;
1103     state->cols_x[min_j] = end_x;
1104     state->end_x = end_x + 1;
1105 
1106     /* delete supply row only if the empty, and if not last row */
1107     if( state->s[min_i] == 0 && u_head->next->next != 0 )
1108         prev_u_min_i->next = prev_u_min_i->next->next;  /* remove row from list */
1109     else
1110         prev_v_min_j->next = prev_v_min_j->next->next;  /* remove column from list */
1111 }
1112 
1113 
1114 /****************************************************************************************\
1115 *                                  standard  metrics                                     *
1116 \****************************************************************************************/
/*
 * L1 (Manhattan) distance between two float vectors.
 * user_param carries the number of dimensions, cast to a pointer.
 * Absolute differences are accumulated in double and the result is
 * converted back to float.
 */
static float
icvDistL1( const float *x, const float *y, void *user_param )
{
    int remaining = (int)(size_t)user_param;
    double total = 0.;

    for( ; remaining > 0; remaining-- )
        total += fabs( (double)(*x++ - *y++) );

    return (float)total;
}
1131 
/*
 * L2 (Euclidean) distance between two float vectors.
 * user_param carries the number of dimensions, cast to a pointer.
 * Squared differences are summed in double. The sum is narrowed to float
 * before the square root is taken with cvSqrt.
 */
static float
icvDistL2( const float *x, const float *y, void *user_param )
{
    int remaining = (int)(size_t)user_param;
    double sum_sq = 0.;

    for( ; remaining > 0; remaining-- )
    {
        double d = (double)(*x++ - *y++);
        sum_sq += d * d;
    }

    return cvSqrt( (float)sum_sq );
}
1146 
/*
 * Chebyshev (L-infinity) distance: the largest absolute per-component
 * difference between two float vectors.
 * user_param carries the number of dimensions, cast to a pointer.
 * Returns 0 when there are no dimensions.
 */
static float
icvDistC( const float *x, const float *y, void *user_param )
{
    const int dims = (int)(size_t)user_param;
    double max_abs = 0;
    int k;

    for( k = 0; k < dims; ++k )
    {
        double d = fabs( x[k] - y[k] );
        max_abs = d > max_abs ? d : max_abs;
    }
    return (float)max_abs;
}
1162 
1163 /* End of file. */
1164 
1165