1 /*
2  * Copyright (C) 2008 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /* ---- includes ----------------------------------------------------------- */
18 
19 #include "b_TensorEm/Int32Mat.h"
20 #include "b_TensorEm/Functions.h"
21 #include "b_BasicEm/Math.h"
22 #include "b_BasicEm/Functions.h"
23 #include "b_BasicEm/Memory.h"
24 
25 /* ------------------------------------------------------------------------- */
26 
27 /* ========================================================================= */
28 /*                                                                           */
29 /* ---- \ghd{ auxiliary functions } ---------------------------------------- */
30 /*                                                                           */
31 /* ========================================================================= */
32 
33 /* ------------------------------------------------------------------------- */
34 
/* Reduces the elements of a fixed-point int32 array so that the largest
 * absolute value fits into nBitsA bits.
 *
 * The required shift is  bts_absIntLog2( maxAbs ) + 1 - nBitsA.  If it is
 * positive, every element is arithmetically right-shifted by that amount
 * (rounded to nearest, ties toward +infinity) and *bbpPtrA (the number of
 * bits behind the binary point) is decreased by the same amount, so the
 * represented values stay approximately unchanged. Otherwise neither the
 * array nor *bbpPtrA is modified.
 *
 * ptrA    : array of sizeA elements, modified in place
 * sizeA   : number of elements
 * bbpPtrA : in/out bits-behind-point of the array
 * nBitsA  : target bit count for the largest magnitude
 *
 * NOTE: like the rest of this file, this relies on >> of negative values
 * being an arithmetic shift (implementation-defined in C).
 */
void bts_Int32Mat_reduceToNBits( int32* ptrA, uint32 sizeA, int32* bbpPtrA, uint32 nBitsA )
{
	int32 maxL = 0;
	int32 shiftL;
	uint32 iL;

	/* find the largest absolute element */
	for( iL = 0; iL < sizeA; iL++ )
	{
		int32 xL = ptrA[ iL ];
		if( xL < 0 )
		{
			/* -INT32_MIN is not representable (undefined behavior); clamp it */
			xL = ( xL == ( -2147483647 - 1 ) ) ? 2147483647 : -xL;
		}
		if( xL > maxL ) maxL = xL;
	}

	/* determine shift; evaluated in signed arithmetic so that a negative
	 * result (no reduction needed) does not wrap around as uint32 */
	shiftL = ( int32 )bts_absIntLog2( maxL ) + 1 - ( int32 )nBitsA;
	if( shiftL <= 0 ) return;

	for( iL = 0; iL < sizeA; iL++ )
	{
		/* same value as ( ( x >> ( s - 1 ) ) + 1 ) >> 1, but the "+ 1"
		 * can no longer overflow for x close to INT32_MAX */
		int32 xL = ptrA[ iL ];
		ptrA[ iL ] = ( xL >> shiftL ) + ( ( xL >> ( shiftL - 1 ) ) & 1 );
	}

	*bbpPtrA -= shiftL;
}
66 
67 /* ------------------------------------------------------------------------- */
68 
69 /* ========================================================================= */
70 /*                                                                           */
71 /* ---- \ghd{ constructor / destructor } ----------------------------------- */
72 /*                                                                           */
73 /* ========================================================================= */
74 
75 /* ------------------------------------------------------------------------- */
76 
bts_Int32Mat_init(struct bbs_Context * cpA,struct bts_Int32Mat * ptrA)77 void bts_Int32Mat_init( struct bbs_Context* cpA,
78 					    struct bts_Int32Mat* ptrA )
79 {
80 	ptrA->widthE = 0;
81 	bbs_Int32Arr_init( cpA, &ptrA->arrE );
82 }
83 
84 /* ------------------------------------------------------------------------- */
85 
bts_Int32Mat_exit(struct bbs_Context * cpA,struct bts_Int32Mat * ptrA)86 void bts_Int32Mat_exit( struct bbs_Context* cpA,
87 					    struct bts_Int32Mat* ptrA )
88 {
89 	ptrA->widthE = 0;
90 	bbs_Int32Arr_exit( cpA, &ptrA->arrE );
91 }
92 /* ------------------------------------------------------------------------- */
93 
94 /* ========================================================================= */
95 /*                                                                           */
96 /* ---- \ghd{ operators } -------------------------------------------------- */
97 /*                                                                           */
98 /* ========================================================================= */
99 
100 /* ------------------------------------------------------------------------- */
101 
102 /* ========================================================================= */
103 /*                                                                           */
104 /* ---- \ghd{ query functions } -------------------------------------------- */
105 /*                                                                           */
106 /* ========================================================================= */
107 
108 /* ------------------------------------------------------------------------- */
109 
110 /* ========================================================================= */
111 /*                                                                           */
112 /* ---- \ghd{ modify functions } ------------------------------------------- */
113 /*                                                                           */
114 /* ========================================================================= */
115 
116 /* ------------------------------------------------------------------------- */
117 
/* Creates storage for a square widthA x widthA matrix; mspA is handed on to
 * bbs_Int32Arr_create. Does nothing when the context already carries an
 * error. widthE is assigned unconditionally after the array create call. */
void bts_Int32Mat_create( struct bbs_Context* cpA,
						  struct bts_Int32Mat* ptrA,
						  int32 widthA,
				          struct bbs_MemSeg* mspA )
{
	int32 elemCountL;

	if( bbs_Context_error( cpA ) )
	{
		return;
	}

	/* square matrix: width * width elements */
	elemCountL = widthA * widthA;
	bbs_Int32Arr_create( cpA, &ptrA->arrE, elemCountL, mspA );
	ptrA->widthE = widthA;
}
127 
128 /* ------------------------------------------------------------------------- */
129 
bts_Int32Mat_copy(struct bbs_Context * cpA,struct bts_Int32Mat * ptrA,const struct bts_Int32Mat * srcPtrA)130 void bts_Int32Mat_copy( struct bbs_Context* cpA,
131 					    struct bts_Int32Mat* ptrA,
132 						const struct bts_Int32Mat* srcPtrA )
133 {
134 	if( ptrA->widthE != srcPtrA->widthE )
135 	{
136 		bbs_ERROR0( "void bts_Int32Mat_copy( struct bts_Int32Mat* ptrA, struct bts_Int32Mat* srcPtrA ):\n"
137 			       "size mismatch" );
138 		return;
139 	}
140 
141 	bbs_Int32Arr_copy( cpA, &ptrA->arrE, &srcPtrA->arrE );
142 }
143 
144 /* ------------------------------------------------------------------------- */
145 
146 /* ========================================================================= */
147 /*                                                                           */
148 /* ---- \ghd{ I/O } -------------------------------------------------------- */
149 /*                                                                           */
150 /* ========================================================================= */
151 
152 /* ------------------------------------------------------------------------- */
153 
bts_Int32Mat_memSize(struct bbs_Context * cpA,const struct bts_Int32Mat * ptrA)154 uint32 bts_Int32Mat_memSize( struct bbs_Context* cpA,
155 							 const struct bts_Int32Mat *ptrA )
156 {
157 	return  bbs_SIZEOF16( uint32 )
158 		  + bbs_SIZEOF16( uint32 ) /* version */
159 		  + bbs_SIZEOF16( ptrA->widthE )
160 		  + bbs_Int32Arr_memSize( cpA, &ptrA->arrE );
161 }
162 
163 /* ------------------------------------------------------------------------- */
164 
bts_Int32Mat_memWrite(struct bbs_Context * cpA,const struct bts_Int32Mat * ptrA,uint16 * memPtrA)165 uint32 bts_Int32Mat_memWrite( struct bbs_Context* cpA,
166 							  const struct bts_Int32Mat* ptrA,
167 							  uint16* memPtrA )
168 {
169 	uint32 memSizeL = bts_Int32Mat_memSize( cpA, ptrA );
170 	memPtrA += bbs_memWrite32( &memSizeL, memPtrA );
171 	memPtrA += bbs_memWriteUInt32( bts_INT32MAT_VERSION, memPtrA );
172 	memPtrA += bbs_memWrite32( &ptrA->widthE, memPtrA );
173 	memPtrA += bbs_Int32Arr_memWrite( cpA, &ptrA->arrE, memPtrA );
174 	return memSizeL;
175 }
176 
177 /* ------------------------------------------------------------------------- */
178 
bts_Int32Mat_memRead(struct bbs_Context * cpA,struct bts_Int32Mat * ptrA,const uint16 * memPtrA,struct bbs_MemSeg * mspA)179 uint32 bts_Int32Mat_memRead( struct bbs_Context* cpA,
180 							 struct bts_Int32Mat* ptrA,
181 							 const uint16* memPtrA,
182 				             struct bbs_MemSeg* mspA )
183 {
184 	uint32 memSizeL, versionL;
185 	if( bbs_Context_error( cpA ) ) return 0;
186 	memPtrA += bbs_memRead32( &memSizeL, memPtrA );
187 	memPtrA += bbs_memReadVersion32( cpA, &versionL, bts_INT32MAT_VERSION, memPtrA );
188 	memPtrA += bbs_memRead32( &ptrA->widthE, memPtrA );
189 	memPtrA += bbs_Int32Arr_memRead( cpA, &ptrA->arrE, memPtrA, mspA );
190 
191 	if( memSizeL != bts_Int32Mat_memSize( cpA, ptrA ) )
192 	{
193 		bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_Int32Mat_memRead( const struct bts_Int32Mat* ptrA, const void* memPtrA ):\n"
194                   "size mismatch" );
195 	}
196 	return memSizeL;
197 }
198 
199 /* ------------------------------------------------------------------------- */
200 
201 /* ========================================================================= */
202 /*                                                                           */
203 /* ---- \ghd{ exec functions } --------------------------------------------- */
204 /*                                                                           */
205 /* ========================================================================= */
206 
207 /* ------------------------------------------------------------------------- */
208 
bts_Int32Mat_solve(struct bbs_Context * cpA,const int32 * matA,int32 matWidthA,const int32 * inVecA,int32 * outVecA,int32 bbpA,int32 * tmpMatA,int32 * tmpVecA)209 flag bts_Int32Mat_solve( struct bbs_Context* cpA,
210 						 const int32* matA,
211 						 int32 matWidthA,
212 						 const int32* inVecA,
213 						 int32* outVecA,
214 						 int32 bbpA,
215 						 int32* tmpMatA,
216 						 int32* tmpVecA )
217 {
218 	bbs_memcpy32( tmpMatA, matA, ( matWidthA * matWidthA ) * bbs_SIZEOF32( int32 ) );
219 
220 	return bts_Int32Mat_solve2( cpA,
221 		                        tmpMatA,
222 								matWidthA,
223 								inVecA,
224 								outVecA,
225 								bbpA,
226 								tmpVecA );
227 }
228 
229 /* ------------------------------------------------------------------------- */
230 
bts_Int32Mat_solve2(struct bbs_Context * cpA,int32 * matA,int32 matWidthA,const int32 * inVecA,int32 * outVecA,int32 bbpA,int32 * tmpVecA)231 flag bts_Int32Mat_solve2( struct bbs_Context* cpA,
232 						  int32* matA,
233 						  int32 matWidthA,
234 						  const int32* inVecA,
235 						  int32* outVecA,
236 						  int32 bbpA,
237 						  int32* tmpVecA )
238 {
239 	int32 sizeL = matWidthA;
240 	int32 bbpL = bbpA;
241 	int32 iL, jL, kL;
242 	int32 iPivL;
243 	int32 jPivL;
244 
245 	int32* vecL      = outVecA;
246 	int32* matL      = matA;
247 	int32* checkArrL = tmpVecA;
248 
249 	for( iL = 0; iL < sizeL; iL++ )
250 	{
251 		checkArrL[ iL ] = 0;
252 	}
253 
254 	bbs_memcpy32( outVecA, inVecA, sizeL * bbs_SIZEOF32( int32 ) );
255 
256 	iPivL = 0;
257 
258 	for( kL = 0; kL < sizeL; kL++ )
259 	{
260 		/* find pivot */
261 		int32 maxAbsL = 0;
262 		int32* pivRowL;
263 
264 		int32 bbp_pivRowL, bbp_vecL, shiftL;
265 
266 		jPivL = -1;
267 		for( iL = 0; iL < sizeL; iL++ )
268 		{
269 			if( checkArrL[ iL ] != 1 )
270 			{
271 				int32* rowL = matL + ( iL * sizeL );
272 				for( jL = 0; jL < sizeL; jL++ )
273 				{
274 					if( checkArrL[ jL ] == 0 )
275 					{
276 						int32 absElemL = rowL[ jL ];
277 						if( absElemL < 0 ) absElemL = -absElemL;
278 						if( maxAbsL < absElemL )
279 						{
280 							maxAbsL = absElemL;
281 							iPivL = iL;
282 							jPivL = jL;
283 						}
284 					}
285 					else if( checkArrL[ jL ] > 1 )
286 					{
287 						return FALSE;
288 					}
289 				}
290 			}
291 		}
292 
293 		/* successfull ? */
294 		if( jPivL < 0 )
295 		{
296 			return FALSE;
297 		}
298 
299 		checkArrL[ jPivL ]++;
300 
301 		/* exchange rows to put pivot on diagonal, if neccessary */
302 		if( iPivL != jPivL )
303 		{
304 			int32* row1PtrL = matL + ( iPivL * sizeL );
305 			int32* row2PtrL = matL + ( jPivL * sizeL );
306 			for( jL = 0; jL < sizeL; jL++ )
307 			{
308 				int32 tmpL = *row1PtrL;
309 				*row1PtrL++ = *row2PtrL;
310 				*row2PtrL++ = tmpL;
311 			}
312 
313 			{
314 				int32 tmpL = vecL[ jPivL ];
315 				vecL[ jPivL ] = vecL[ iPivL ];
316 				vecL[ iPivL ] = tmpL;
317 			}
318 		}
319 		/* now index jPivL specifies pivot row and maximum element */
320 
321 
322 		/**	Overflow protection: only if the highest bit of the largest matrix element is set,
323 		 *	we need to shift the whole matrix and the right side vector 1 bit to the right,
324 		 *	to make sure there can be no overflow when the pivot row gets subtracted from the
325 		 *	other rows.
326 		 *	Getting that close to overflow is a rare event, so this shift will happen only
327 		 *	occasionally, or not at all.
328 		 */
329 		if( maxAbsL & 1073741824 )  /*( 1 << 30 )*/
330 		{
331 			/* right shift matrix by 1 */
332 			int32 iL = sizeL * sizeL;
333 			int32* ptrL = matL;
334 			while( iL-- )
335 			{
336 				*ptrL = ( *ptrL + 1 ) >> 1;
337 				ptrL++;
338 			}
339 
340 			/* right shift right side vector by 1 */
341 			iL = sizeL;
342 			ptrL = vecL;
343 			while( iL-- )
344 			{
345 				*ptrL = ( *ptrL + 1 ) >> 1;
346 				ptrL++;
347 			}
348 
349 			/* decrement bbpL */
350 			bbpL--;
351 		}
352 
353 
354 		/* reduce elements of pivot row to 15 bit */
355 		pivRowL = matL + jPivL * sizeL;
356 		bbp_pivRowL = bbpL;
357 		bts_Int32Mat_reduceToNBits( pivRowL, sizeL, &bbp_pivRowL, 15 );
358 
359 		/* scale pivot row such that maximum equals 1 */
360 		{
361 			int32 maxL = pivRowL[ jPivL ];
362 			int32 bbp_maxL = bbp_pivRowL;
363 			int32 factorL = 1073741824 / maxL; /*( 1 << 30 )*/
364 
365 			for( jL = 0; jL < sizeL; jL++ )
366 			{
367 				pivRowL[ jL ] = ( pivRowL[ jL ] * factorL + ( 1 << 14 ) ) >> 15;
368 			}
369 			bbp_pivRowL = 15;
370 
371 			/* set to 1 to avoid computational errors */
372 			pivRowL[ jPivL ] = ( int32 )1 << bbp_pivRowL;
373 
374 			shiftL = 30 - bts_absIntLog2( vecL[ jPivL ] );
375 
376 			vecL[ jPivL ] = ( vecL[ jPivL ] << shiftL ) / maxL;
377 			bbp_vecL = bbpL + shiftL - bbp_maxL;
378 
379 			bbs_int32ReduceToNBits( &( vecL[ jPivL ] ), &bbp_vecL, 15 );
380 		}
381 
382 		/* subtract pivot row from all other rows */
383 		for( iL = 0; iL < sizeL; iL++ )
384 		{
385 			if( iL != jPivL )
386 			{
387 				int32* rowPtrL = matL + iL * sizeL;
388 
389 				int32 tmpL = *( rowPtrL + jPivL );
390 				int32 bbp_tmpL = bbpL;
391 				bbs_int32ReduceToNBits( &tmpL, &bbp_tmpL, 15 );
392 
393 				shiftL = bbp_tmpL + bbp_pivRowL - bbpL;
394 				if( shiftL > 0 )
395 				{
396 					for( jL = 0; jL < sizeL; jL++ )
397 					{
398 						*rowPtrL++ -= ( ( ( tmpL * pivRowL[ jL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
399 					}
400 				}
401 				else
402 				{
403 					for( jL = 0; jL < sizeL; jL++ )
404 					{
405 						*rowPtrL++ -= ( tmpL * pivRowL[ jL ] ) << -shiftL;
406 					}
407 				}
408 
409 				shiftL = bbp_tmpL + bbp_vecL - bbpL;
410 				if( shiftL > 0 )
411 				{
412 					vecL[ iL ] -= ( ( ( tmpL * vecL[ jPivL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
413 				}
414 				else
415 				{
416 					vecL[ iL ] -= ( tmpL * vecL[ jPivL ] ) << -shiftL;
417 				}
418 			}
419 		}
420 
421 		/* change bbp of pivot row back to bbpL */
422 		shiftL = bbpL - bbp_pivRowL;
423 		if( shiftL >= 0 )
424 		{
425 			for( jL = 0; jL < sizeL; jL++ )
426 			{
427 				pivRowL[ jL ] <<= shiftL;
428 			}
429 		}
430 		else
431 		{
432 			shiftL = -shiftL;
433 			for( jL = 0; jL < sizeL; jL++ )
434 			{
435 				pivRowL[ jL ] = ( ( pivRowL[ jL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
436 			}
437 		}
438 
439 		shiftL = bbpL - bbp_vecL;
440 		if( shiftL >= 0 )
441 		{
442 			vecL[ jPivL ] <<= shiftL;
443 		}
444 		else
445 		{
446 			shiftL = -shiftL;
447 			vecL[ jPivL ] = ( ( vecL[ jPivL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
448 		}
449 /*
450 if( sizeL <= 5 ) bts_Int32Mat_print( matL, vecL, sizeL, bbpL );
451 */
452 	}	/* of kL */
453 
454 	/* in case bbpL has been decreased by the overflow protection, change it back now */
455 	if( bbpA > bbpL )
456 	{
457 		/* find largest element of solution vector */
458 		int32 maxL = 0;
459 		int32 iL, shiftL;
460 		for( iL = 0; iL < sizeL; iL++ )
461 		{
462 			int32 xL = vecL[ iL ];
463 			if( xL < 0 ) xL = -xL;
464 			if( xL > maxL ) maxL = xL;
465 		}
466 
467 		/* check whether we can left shift without overflow */
468 		shiftL = 30 - bts_absIntLog2( maxL );
469 		if( shiftL < ( bbpA - bbpL ) )
470 		{
471 			/*
472 			    bbs_WARNING1( "flag bts_Int32Mat_solve2( ... ): getting overflow when trying to "
473 				"compute solution vector with bbp = %d. Choose smaller bbp.\n", bbpA );
474 			*/
475 
476 			return FALSE;
477 		}
478 
479 		/* shift left */
480 		shiftL = bbpA - bbpL;
481 		for( iL = 0; iL < sizeL; iL++ ) vecL[ iL ] <<= shiftL;
482 	}
483 
484 	return TRUE;
485 }
486 
487 /* ------------------------------------------------------------------------- */
488 
489 /* ========================================================================= */
490 
491