1 /*
2   Name:     imath.c
3   Purpose:  Arbitrary precision integer arithmetic routines.
4   Author:   M. J. Fromberger <http://spinning-yarns.org/michael/>
5 
6   Copyright (C) 2002-2007 Michael J. Fromberger, All Rights Reserved.
7 
8   Permission is hereby granted, free of charge, to any person obtaining a copy
9   of this software and associated documentation files (the "Software"), to deal
10   in the Software without restriction, including without limitation the rights
11   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12   copies of the Software, and to permit persons to whom the Software is
13   furnished to do so, subject to the following conditions:
14 
15   The above copyright notice and this permission notice shall be included in
16   all copies or substantial portions of the Software.
17 
18   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
21   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24   SOFTWARE.
25  */
26 
27 #include "imath.h"
28 
29 #if DEBUG
30 #include <stdio.h>
31 #endif
32 
33 #include <stdlib.h>
34 #include <string.h>
35 #include <ctype.h>
36 
37 #include <assert.h>
38 
39 #if DEBUG
40 #define STATIC /* public */
41 #else
42 #define STATIC static
43 #endif
44 
/* Result codes returned by the mp_int_* routines.  MP_OK and MP_FALSE share
   the value 0; every other code is negative.  The entries of s_error_msg
   below are ordered so that index -k holds the message for code k. */
const mp_result MP_OK     = 0;  /* no error, all is well  */
const mp_result MP_FALSE  = 0;  /* boolean false          */
const mp_result MP_TRUE   = -1; /* boolean true           */
const mp_result MP_MEMORY = -2; /* out of memory          */
const mp_result MP_RANGE  = -3; /* argument out of range  */
const mp_result MP_UNDEF  = -4; /* result undefined       */
const mp_result MP_TRUNC  = -5; /* output truncated       */
const mp_result MP_BADARG = -6; /* invalid null argument  */
/* Most negative code defined above; keep in sync when adding codes. */
const mp_result MP_MINERR = -6;

/* Sign flags stored in an mp_int.  Zero is always stored as MP_ZPOS. */
const mp_sign   MP_NEG  = 1;    /* value is strictly negative */
const mp_sign   MP_ZPOS = 0;    /* value is non-negative      */
57 
/* Fallback text for a result code with no entry in s_error_msg. */
STATIC const char *s_unknown_err = "unknown result code";

/* Human-readable messages indexed by the negated result code
   (index 0 -> code 0, index 2 -> MP_MEMORY (-2), ...).  The table is
   NULL-terminated so its length can be found by scanning. */
STATIC const char *s_error_msg[] = {
  "error code 0",
  "boolean true",
  "out of memory",
  "argument out of range",
  "result undefined",
  "output truncated",
  "invalid argument",
  NULL
};
69 
/* Argument checking macros
   Use CHECK() where a return value is required; NRCHECK() elsewhere.
   Both currently expand to assert(), so they vanish when NDEBUG is defined
   and then provide no runtime argument validation at all. */
#define CHECK(TEST)   assert(TEST)
#define NRCHECK(TEST) assert(TEST)
74 
/* The ith entry of this table gives the value of log_i(2).

   A positive integer n requires floor(log_i(n)) + 1 digits to be
   represented in base i.  Since it is easy to compute lg(n) by counting
   bits, we can compute log_i(n) = lg(n) * log_i(2).

   The use of this table eliminates a dependency upon linkage against
   the standard math libraries.

   Entries 0 and 1 are placeholders (0 and 1 are not valid radixes).
   If MP_MAX_RADIX is increased, this table should be expanded too.
 */
STATIC const double s_log2[] = {
   0.000000000, 0.000000000, 1.000000000, 0.630929754, 	/* (D)(D) 2  3 */
   0.500000000, 0.430676558, 0.386852807, 0.356207187, 	/*  4  5  6  7 */
   0.333333333, 0.315464877, 0.301029996, 0.289064826, 	/*  8  9 10 11 */
   0.278942946, 0.270238154, 0.262649535, 0.255958025, 	/* 12 13 14 15 */
   0.250000000, 0.244650542, 0.239812467, 0.235408913, 	/* 16 17 18 19 */
   0.231378213, 0.227670249, 0.224243824, 0.221064729, 	/* 20 21 22 23 */
   0.218104292, 0.215338279, 0.212746054, 0.210309918, 	/* 24 25 26 27 */
   0.208014598, 0.205846832, 0.203795047, 0.201849087, 	/* 28 29 30 31 */
   0.200000000, 0.198239863, 0.196561632, 0.194959022, 	/* 32 33 34 35 */
   0.193426404,                                         /* 36          */
};
98 
99 
100 
/* Return the number of digits needed to represent a static value:
   ceil(sizeof(V) / sizeof(mp_digit)).  V is only used in sizeof, so it is
   not evaluated. */
#define MP_VALUE_DIGITS(V) \
((sizeof(V)+(sizeof(mp_digit)-1))/sizeof(mp_digit))

/* Round precision P up to the next even number of digits
   (e.g. 3 -> 4, 4 -> 4). */
#define ROUND_PREC(P) ((mp_size)(2*(((P)+1)/2)))

/* Set array P of S digits to zero */
#define ZERO(P, S) \
do{ \
  mp_size i__ = (S) * sizeof(mp_digit); \
  mp_digit *p__ = (P); \
  memset(p__, 0, i__); \
} while(0)

/* Copy S digits from array P to array Q.  Uses memcpy, so the two
   arrays must not overlap. */
#define COPY(P, Q, S) \
do{ \
  mp_size i__ = (S) * sizeof(mp_digit); \
  mp_digit *p__ = (P), *q__ = (Q); \
  memcpy(q__, p__, i__); \
} while(0)

/* Reverse N elements of type T in array A */
#define REV(T, A, N) \
do{ \
  T *u_ = (A), *v_ = u_ + (N) - 1; \
  while (u_ < v_) { \
    T xch = *u_; \
    *u_++ = *v_; \
    *v_-- = xch; \
  } \
} while(0)

/* Drop high-order zero digits from Z by lowering its used count, always
   keeping at least one digit.  The sign of Z is not touched. */
#define CLAMP(Z) \
do{ \
  mp_int z_ = (Z); \
  mp_size uz_ = MP_USED(z_); \
  mp_digit *dz_ = MP_DIGITS(z_) + uz_ -1; \
  while (uz_ > 1 && (*dz_-- == 0)) \
    --uz_; \
  MP_USED(z_) = uz_; \
} while(0)

/* Select min/max.  Do not provide expressions for which multiple
   evaluation would be problematic, e.g. x++ */
#define MIN(A, B) ((B)<(A)?(B):(A))
#define MAX(A, B) ((B)>(A)?(B):(A))

/* Exchange lvalues A and B of type T, e.g.
   SWAP(int, x, y) where x and y are variables of type int. */
#define SWAP(T, A, B) \
do{ \
  T t_ = (A); \
  A = (B); \
  B = t_; \
} while(0)
158 
/* Used to set up and access simple temp stacks within functions.
   DECLARE_TEMP(N) declares N mpz_t slots plus a counter of how many are
   live.  SETUP(E) evaluates an initializer E into the local `res`; on
   failure it jumps to CLEANUP, otherwise it marks one more slot live.
   LAST_TEMP() is the next slot to initialize.  CLEANUP_TEMP() provides
   the CLEANUP label and clears every live slot in reverse order.  A
   function using SETUP must declare `mp_result res` and use CLEANUP_TEMP. */
#define DECLARE_TEMP(N) \
  mpz_t temp[(N)]; \
  int last__ = 0
#define CLEANUP_TEMP() \
 CLEANUP: \
  while (--last__ >= 0) \
    mp_int_clear(TEMP(last__))
#define TEMP(K) (temp + (K))
#define LAST_TEMP() TEMP(last__)
#define SETUP(E) \
do{ \
  if ((res = (E)) != MP_OK) \
    goto CLEANUP; \
  ++(last__); \
} while(0)

/* Compare value to zero: 0 for zero, -1 if negative, 1 if positive.
   Z is evaluated several times. */
#define CMPZ(Z) \
(((Z)->used==1&&(Z)->digits[0]==0)?0:((Z)->sign==MP_NEG)?-1:1)

/* Multiply X by Y into Z, ignoring signs.  Requires that Z have
   enough storage preallocated to hold the result.
   NOTE(review): the result of s_kmul is discarded here; mp_int_mul treats
   a zero result from s_kmul as out-of-memory. */
#define UMUL(X, Y, Z) \
do{ \
  mp_size ua_ = MP_USED(X), ub_ = MP_USED(Y); \
  mp_size o_ = ua_ + ub_; \
  ZERO(MP_DIGITS(Z), o_); \
  (void) s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), ua_, ub_); \
  MP_USED(Z) = o_; \
  CLAMP(Z); \
} while(0)

/* Square X into Z.  Requires that Z have enough storage to hold the
   result.  The result of s_ksqr is discarded, as in UMUL. */
#define USQR(X, Z) \
do{ \
  mp_size ua_ = MP_USED(X), o_ = ua_ + ua_; \
  ZERO(MP_DIGITS(Z), o_); \
  (void) s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), ua_); \
  MP_USED(Z) = o_; \
  CLAMP(Z); \
} while(0)

/* Double-width word helpers: high digit, low digit, top-bit test, and
   whether W + V would exceed MP_WORD_MAX. */
#define UPPER_HALF(W)           ((mp_word)((W) >> MP_DIGIT_BIT))
#define LOWER_HALF(W)           ((mp_digit)(W))
#define HIGH_BIT_SET(W)         ((W) >> (MP_WORD_BIT - 1))
#define ADD_WILL_OVERFLOW(W, V) ((MP_WORD_MAX - (V)) < (W))
207 
208 
209 
/* Default number of digits allocated to a new mp_int.
   In IMATH_TEST builds this is a mutable global so tests can tune it;
   otherwise it is a file-local constant. */
#if IMATH_TEST
mp_size default_precision = MP_DEFAULT_PREC;
#else
STATIC const mp_size default_precision = MP_DEFAULT_PREC;
#endif

/* Minimum number of digits to invoke recursive multiply
   (mutable under IMATH_TEST, constant otherwise). */
#if IMATH_TEST
mp_size multiply_threshold = MP_MULT_THRESH;
#else
STATIC const mp_size multiply_threshold = MP_MULT_THRESH;
#endif
223 
/* Prototypes for the file-local helper routines.  They are static unless
   DEBUG is set (see the STATIC macro). */

/* Allocate a buffer of (at least) num digits, or return
   NULL if that couldn't be done.  */
STATIC mp_digit *s_alloc(mp_size num);

/* Release a buffer of digits allocated by s_alloc(). */
STATIC void s_free(void *ptr);

/* Ensure that z has at least min digits allocated, resizing if
   necessary.  Returns true if successful, false if out of memory. */
STATIC int  s_pad(mp_int z, mp_size min);

/* Fill in a "fake" mp_int on the stack with a given value, using the
   caller-supplied buffer vbuf as its digit storage. */
STATIC void      s_fake(mp_int z, mp_small value, mp_digit vbuf[]);
STATIC void      s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]);

/* Compare two runs of digits of given length, returns <0, 0, >0 */
STATIC int       s_cdig(mp_digit *da, mp_digit *db, mp_size len);

/* Pack the unsigned digits of v into array t */
STATIC int       s_uvpack(mp_usmall v, mp_digit t[]);

/* Compare magnitudes of a and b, returns <0, 0, >0 */
STATIC int       s_ucmp(mp_int a, mp_int b);

/* Compare magnitudes of a and v, returns <0, 0, >0 */
STATIC int       s_vcmp(mp_int a, mp_small v);
STATIC int       s_uvcmp(mp_int a, mp_usmall uv);

/* Unsigned magnitude addition; assumes dc is big enough.
   Carry out is returned (no memory allocated). */
STATIC mp_digit  s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc,
		        mp_size size_a, mp_size size_b);

/* Unsigned magnitude subtraction.  Assumes dc is big enough.
   Callers pass the larger magnitude as da. */
STATIC void      s_usub(mp_digit *da, mp_digit *db, mp_digit *dc,
		        mp_size size_a, mp_size size_b);

/* Unsigned recursive multiplication.  Assumes dc is big enough.
   mp_int_mul treats a zero return as out-of-memory. */
STATIC int       s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc,
			mp_size size_a, mp_size size_b);

/* Unsigned magnitude multiplication.  Assumes dc is big enough. */
STATIC void      s_umul(mp_digit *da, mp_digit *db, mp_digit *dc,
			mp_size size_a, mp_size size_b);

/* Unsigned recursive squaring.  Assumes dc is big enough.  Returns int
   like s_kmul. */
STATIC int       s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a);

/* Unsigned magnitude squaring.  Assumes dc is big enough. */
STATIC void      s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a);

/* Single digit addition.  Assumes a is big enough. */
STATIC void      s_dadd(mp_int a, mp_digit b);

/* Single digit multiplication.  Assumes a is big enough. */
STATIC void      s_dmul(mp_int a, mp_digit b);

/* Single digit multiplication on buffers; assumes dc is big enough. */
STATIC void      s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc,
			 mp_size size_a);

/* Single digit division.  Replaces a with the quotient,
   returns the remainder.  */
STATIC mp_digit  s_ddiv(mp_int a, mp_digit b);

/* Quick division by a power of 2, replaces z (no allocation) */
STATIC void      s_qdiv(mp_int z, mp_size p2);

/* Quick remainder by a power of 2, replaces z (no allocation) */
STATIC void      s_qmod(mp_int z, mp_size p2);

/* Quick multiplication by a power of 2, replaces z.
   Allocates if necessary; returns false in case this fails. */
STATIC int       s_qmul(mp_int z, mp_size p2);

/* Quick subtraction from a power of 2, replaces z.
   Allocates if necessary; returns false in case this fails. */
STATIC int       s_qsub(mp_int z, mp_size p2);

/* Return maximum k such that 2^k divides z. */
STATIC int       s_dp2k(mp_int z);

/* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */
STATIC int       s_isp2(mp_int z);

/* Set z to 2^k.  May allocate; returns false in case this fails. */
STATIC int       s_2expt(mp_int z, mp_small k);

/* Normalize a and b for division, returns normalization constant */
STATIC int       s_norm(mp_int a, mp_int b);

/* Compute constant mu for Barrett reduction, given modulus m, result
   replaces z, m is untouched. */
STATIC mp_result s_brmu(mp_int z, mp_int m);

/* Reduce a modulo m, using Barrett's algorithm. */
STATIC int       s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2);

/* Modular exponentiation, using Barrett reduction */
STATIC mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c);

/* Unsigned magnitude division.  Assumes |a| > |b|.  Allocates temporaries;
   overwrites a with quotient, b with remainder. */
STATIC mp_result s_udiv_knuth(mp_int a, mp_int b);

/* Compute the number of digits in radix r required to represent the given
   value.  Does not account for sign flags, terminators, etc. */
STATIC int       s_outlen(mp_int z, mp_size r);

/* Guess how many digits of precision will be needed to represent a radix r
   value of the specified number of digits.  Returns a value guaranteed to be
   no smaller than the actual number required. */
STATIC mp_size   s_inlen(int len, mp_size r);

/* Convert a character to a digit value in radix r, or
   -1 if out of range */
STATIC int       s_ch2val(char c, int r);

/* Convert a digit value to a character */
STATIC char      s_val2ch(int v, int caps);

/* Take 2's complement of a buffer in place */
STATIC void      s_2comp(unsigned char *buf, int len);

/* Convert a value to binary, ignoring sign.  On input, *limpos is the bound on
   how many bytes should be written to buf; on output, *limpos is set to the
   number of bytes actually written. */
STATIC mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad);

#if DEBUG
/* Dump a representation of the mp_int to standard output */
void      s_print(char *tag, mp_int z);
void      s_print_buf(char *tag, mp_digit *buf, mp_size num);
#endif
358 
mp_int_init(mp_int z)359 mp_result mp_int_init(mp_int z)
360 {
361   if (z == NULL)
362     return MP_BADARG;
363 
364   z->single = 0;
365   z->digits = &(z->single);
366   z->alloc  = 1;
367   z->used   = 1;
368   z->sign   = MP_ZPOS;
369 
370   return MP_OK;
371 }
372 
mp_int_alloc(void)373 mp_int    mp_int_alloc(void)
374 {
375   mp_int out = malloc(sizeof(mpz_t));
376 
377   if (out != NULL)
378     mp_int_init(out);
379 
380   return out;
381 }
382 
mp_int_init_size(mp_int z,mp_size prec)383 mp_result mp_int_init_size(mp_int z, mp_size prec)
384 {
385   CHECK(z != NULL);
386 
387   if (prec == 0)
388     prec = default_precision;
389   else if (prec == 1)
390     return mp_int_init(z);
391   else
392     prec = (mp_size) ROUND_PREC(prec);
393 
394   if ((MP_DIGITS(z) = s_alloc(prec)) == NULL)
395     return MP_MEMORY;
396 
397   z->digits[0] = 0;
398   MP_USED(z) = 1;
399   MP_ALLOC(z) = prec;
400   MP_SIGN(z) = MP_ZPOS;
401 
402   return MP_OK;
403 }
404 
mp_int_init_copy(mp_int z,mp_int old)405 mp_result mp_int_init_copy(mp_int z, mp_int old)
406 {
407   mp_result res;
408   mp_size uold;
409 
410   CHECK(z != NULL && old != NULL);
411 
412   uold = MP_USED(old);
413   if (uold == 1) {
414     mp_int_init(z);
415   }
416   else {
417     mp_size target = MAX(uold, default_precision);
418 
419     if ((res = mp_int_init_size(z, target)) != MP_OK)
420       return res;
421   }
422 
423   MP_USED(z) = uold;
424   MP_SIGN(z) = MP_SIGN(old);
425   COPY(MP_DIGITS(old), MP_DIGITS(z), uold);
426 
427   return MP_OK;
428 }
429 
mp_int_init_value(mp_int z,mp_small value)430 mp_result mp_int_init_value(mp_int z, mp_small value)
431 {
432   mpz_t    vtmp;
433   mp_digit vbuf[MP_VALUE_DIGITS(value)];
434 
435   s_fake(&vtmp, value, vbuf);
436   return mp_int_init_copy(z, &vtmp);
437 }
438 
mp_int_init_uvalue(mp_int z,mp_usmall uvalue)439 mp_result mp_int_init_uvalue(mp_int z, mp_usmall uvalue)
440 {
441   mpz_t    vtmp;
442   mp_digit vbuf[MP_VALUE_DIGITS(uvalue)];
443 
444   s_ufake(&vtmp, uvalue, vbuf);
445   return mp_int_init_copy(z, &vtmp);
446 }
447 
mp_int_set_value(mp_int z,mp_small value)448 mp_result  mp_int_set_value(mp_int z, mp_small value)
449 {
450   mpz_t    vtmp;
451   mp_digit vbuf[MP_VALUE_DIGITS(value)];
452 
453   s_fake(&vtmp, value, vbuf);
454   return mp_int_copy(&vtmp, z);
455 }
456 
mp_int_set_uvalue(mp_int z,mp_usmall uvalue)457 mp_result  mp_int_set_uvalue(mp_int z, mp_usmall uvalue)
458 {
459   mpz_t    vtmp;
460   mp_digit vbuf[MP_VALUE_DIGITS(uvalue)];
461 
462   s_ufake(&vtmp, uvalue, vbuf);
463   return mp_int_copy(&vtmp, z);
464 }
465 
/* Release z's digit storage (unless it is the built-in single digit) and
   set its digit pointer to NULL.  Safe to call on NULL or on an already
   cleared value.  The mpz_t struct itself is not freed. */
void      mp_int_clear(mp_int z)
{
  mp_digit *digs;

  if (z == NULL || (digs = MP_DIGITS(z)) == NULL)
    return;

  if (digs != &(z->single))
    s_free(digs);

  MP_DIGITS(z) = NULL;
}
478 
/* Clear z and release the struct itself.  z is expected to come from
   mp_int_alloc(), which obtains it with malloc(); hence free() rather than
   s_free() for the struct. */
void      mp_int_free(mp_int z)
{
  NRCHECK(z != NULL);

  mp_int_clear(z);   /* digit storage first ... */
  free(z);           /* ... then the malloc()ed struct */
}
486 
mp_int_copy(mp_int a,mp_int c)487 mp_result mp_int_copy(mp_int a, mp_int c)
488 {
489   CHECK(a != NULL && c != NULL);
490 
491   if (a != c) {
492     mp_size ua = MP_USED(a);
493     mp_digit *da, *dc;
494 
495     if (!s_pad(c, ua))
496       return MP_MEMORY;
497 
498     da = MP_DIGITS(a); dc = MP_DIGITS(c);
499     COPY(da, dc, ua);
500 
501     MP_USED(c) = ua;
502     MP_SIGN(c) = MP_SIGN(a);
503   }
504 
505   return MP_OK;
506 }
507 
/* Exchange the values of a and c in place, without copying digits. */
void      mp_int_swap(mp_int a, mp_int c)
{
  mpz_t saved;

  if (a == c)
    return;

  saved = *a;
  *a = *c;
  *c = saved;

  /* A value held in its built-in single digit still points into the other
     struct after the exchange; point it at its new owner's digit. */
  if (MP_DIGITS(a) == &(c->single))
    MP_DIGITS(a) = &(a->single);
  if (MP_DIGITS(c) == &(a->single))
    MP_DIGITS(c) = &(c->single);
}
522 
/* Set z to zero without releasing any of its storage. */
void      mp_int_zero(mp_int z)
{
  NRCHECK(z != NULL);

  MP_SIGN(z) = MP_ZPOS;
  MP_USED(z) = 1;
  z->digits[0] = 0;
}
531 
mp_int_abs(mp_int a,mp_int c)532 mp_result mp_int_abs(mp_int a, mp_int c)
533 {
534   mp_result res;
535 
536   CHECK(a != NULL && c != NULL);
537 
538   if ((res = mp_int_copy(a, c)) != MP_OK)
539     return res;
540 
541   MP_SIGN(c) = MP_ZPOS;
542   return MP_OK;
543 }
544 
mp_int_neg(mp_int a,mp_int c)545 mp_result mp_int_neg(mp_int a, mp_int c)
546 {
547   mp_result res;
548 
549   CHECK(a != NULL && c != NULL);
550 
551   if ((res = mp_int_copy(a, c)) != MP_OK)
552     return res;
553 
554   if (CMPZ(c) != 0)
555     MP_SIGN(c) = 1 - MP_SIGN(a);
556 
557   return MP_OK;
558 }
559 
mp_int_add(mp_int a,mp_int b,mp_int c)560 mp_result mp_int_add(mp_int a, mp_int b, mp_int c)
561 {
562   mp_size ua, ub, uc, max;
563 
564   CHECK(a != NULL && b != NULL && c != NULL);
565 
566   ua = MP_USED(a); ub = MP_USED(b); uc = MP_USED(c);
567   max = MAX(ua, ub);
568 
569   if (MP_SIGN(a) == MP_SIGN(b)) {
570     /* Same sign -- add magnitudes, preserve sign of addends */
571     mp_digit carry;
572 
573     if (!s_pad(c, max))
574       return MP_MEMORY;
575 
576     carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
577     uc = max;
578 
579     if (carry) {
580       if (!s_pad(c, max + 1))
581 	return MP_MEMORY;
582 
583       c->digits[max] = carry;
584       ++uc;
585     }
586 
587     MP_USED(c) = uc;
588     MP_SIGN(c) = MP_SIGN(a);
589 
590   }
591   else {
592     /* Different signs -- subtract magnitudes, preserve sign of greater */
593     mp_int  x, y;
594     int     cmp = s_ucmp(a, b); /* magnitude comparision, sign ignored */
595 
596     /* Set x to max(a, b), y to min(a, b) to simplify later code.
597        A special case yields zero for equal magnitudes.
598     */
599     if (cmp == 0) {
600       mp_int_zero(c);
601       return MP_OK;
602     }
603     else if (cmp < 0) {
604       x = b; y = a;
605     }
606     else {
607       x = a; y = b;
608     }
609 
610     if (!s_pad(c, MP_USED(x)))
611       return MP_MEMORY;
612 
613     /* Subtract smaller from larger */
614     s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
615     MP_USED(c) = MP_USED(x);
616     CLAMP(c);
617 
618     /* Give result the sign of the larger */
619     MP_SIGN(c) = MP_SIGN(x);
620   }
621 
622   return MP_OK;
623 }
624 
mp_int_add_value(mp_int a,mp_small value,mp_int c)625 mp_result mp_int_add_value(mp_int a, mp_small value, mp_int c)
626 {
627   mpz_t    vtmp;
628   mp_digit vbuf[MP_VALUE_DIGITS(value)];
629 
630   s_fake(&vtmp, value, vbuf);
631 
632   return mp_int_add(a, &vtmp, c);
633 }
634 
mp_int_sub(mp_int a,mp_int b,mp_int c)635 mp_result mp_int_sub(mp_int a, mp_int b, mp_int c)
636 {
637   mp_size ua, ub, uc, max;
638 
639   CHECK(a != NULL && b != NULL && c != NULL);
640 
641   ua = MP_USED(a); ub = MP_USED(b); uc = MP_USED(c);
642   max = MAX(ua, ub);
643 
644   if (MP_SIGN(a) != MP_SIGN(b)) {
645     /* Different signs -- add magnitudes and keep sign of a */
646     mp_digit carry;
647 
648     if (!s_pad(c, max))
649       return MP_MEMORY;
650 
651     carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
652     uc = max;
653 
654     if (carry) {
655       if (!s_pad(c, max + 1))
656 	return MP_MEMORY;
657 
658       c->digits[max] = carry;
659       ++uc;
660     }
661 
662     MP_USED(c) = uc;
663     MP_SIGN(c) = MP_SIGN(a);
664 
665   }
666   else {
667     /* Same signs -- subtract magnitudes */
668     mp_int  x, y;
669     mp_sign osign;
670     int     cmp = s_ucmp(a, b);
671 
672     if (!s_pad(c, max))
673       return MP_MEMORY;
674 
675     if (cmp >= 0) {
676       x = a; y = b; osign = MP_ZPOS;
677     }
678     else {
679       x = b; y = a; osign = MP_NEG;
680     }
681 
682     if (MP_SIGN(a) == MP_NEG && cmp != 0)
683       osign = 1 - osign;
684 
685     s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
686     MP_USED(c) = MP_USED(x);
687     CLAMP(c);
688 
689     MP_SIGN(c) = osign;
690   }
691 
692   return MP_OK;
693 }
694 
mp_int_sub_value(mp_int a,mp_small value,mp_int c)695 mp_result mp_int_sub_value(mp_int a, mp_small value, mp_int c)
696 {
697   mpz_t    vtmp;
698   mp_digit vbuf[MP_VALUE_DIGITS(value)];
699 
700   s_fake(&vtmp, value, vbuf);
701 
702   return mp_int_sub(a, &vtmp, c);
703 }
704 
mp_int_mul(mp_int a,mp_int b,mp_int c)705 mp_result mp_int_mul(mp_int a, mp_int b, mp_int c)
706 {
707   mp_digit *out;
708   mp_size   osize, ua, ub, p = 0;
709   mp_sign   osign;
710 
711   CHECK(a != NULL && b != NULL && c != NULL);
712 
713   /* If either input is zero, we can shortcut multiplication */
714   if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0) {
715     mp_int_zero(c);
716     return MP_OK;
717   }
718 
719   /* Output is positive if inputs have same sign, otherwise negative */
720   osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG;
721 
722   /* If the output is not identical to any of the inputs, we'll write the
723      results directly; otherwise, allocate a temporary space. */
724   ua = MP_USED(a); ub = MP_USED(b);
725   osize = MAX(ua, ub);
726   osize = 4 * ((osize + 1) / 2);
727 
728   if (c == a || c == b) {
729     p = ROUND_PREC(osize);
730     p = MAX(p, default_precision);
731 
732     if ((out = s_alloc(p)) == NULL)
733       return MP_MEMORY;
734   }
735   else {
736     if (!s_pad(c, osize))
737       return MP_MEMORY;
738 
739     out = MP_DIGITS(c);
740   }
741   ZERO(out, osize);
742 
743   if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub))
744     return MP_MEMORY;
745 
746   /* If we allocated a new buffer, get rid of whatever memory c was already
747      using, and fix up its fields to reflect that.
748    */
749   if (out != MP_DIGITS(c)) {
750     if ((void *) MP_DIGITS(c) != (void *) c)
751       s_free(MP_DIGITS(c));
752     MP_DIGITS(c) = out;
753     MP_ALLOC(c) = p;
754   }
755 
756   MP_USED(c) = osize; /* might not be true, but we'll fix it ... */
757   CLAMP(c);           /* ... right here */
758   MP_SIGN(c) = osign;
759 
760   return MP_OK;
761 }
762 
mp_int_mul_value(mp_int a,mp_small value,mp_int c)763 mp_result mp_int_mul_value(mp_int a, mp_small value, mp_int c)
764 {
765   mpz_t    vtmp;
766   mp_digit vbuf[MP_VALUE_DIGITS(value)];
767 
768   s_fake(&vtmp, value, vbuf);
769 
770   return mp_int_mul(a, &vtmp, c);
771 }
772 
mp_int_mul_pow2(mp_int a,mp_small p2,mp_int c)773 mp_result mp_int_mul_pow2(mp_int a, mp_small p2, mp_int c)
774 {
775   mp_result res;
776   CHECK(a != NULL && c != NULL && p2 >= 0);
777 
778   if ((res = mp_int_copy(a, c)) != MP_OK)
779     return res;
780 
781   if (s_qmul(c, (mp_size) p2))
782     return MP_OK;
783   else
784     return MP_MEMORY;
785 }
786 
mp_int_sqr(mp_int a,mp_int c)787 mp_result mp_int_sqr(mp_int a, mp_int c)
788 {
789   mp_digit *out;
790   mp_size   osize, p = 0;
791 
792   CHECK(a != NULL && c != NULL);
793 
794   /* Get a temporary buffer big enough to hold the result */
795   osize = (mp_size) 4 * ((MP_USED(a) + 1) / 2);
796   if (a == c) {
797     p = ROUND_PREC(osize);
798     p = MAX(p, default_precision);
799 
800     if ((out = s_alloc(p)) == NULL)
801       return MP_MEMORY;
802   }
803   else {
804     if (!s_pad(c, osize))
805       return MP_MEMORY;
806 
807     out = MP_DIGITS(c);
808   }
809   ZERO(out, osize);
810 
811   s_ksqr(MP_DIGITS(a), out, MP_USED(a));
812 
813   /* Get rid of whatever memory c was already using, and fix up its fields to
814      reflect the new digit array it's using
815    */
816   if (out != MP_DIGITS(c)) {
817     if ((void *) MP_DIGITS(c) != (void *) c)
818       s_free(MP_DIGITS(c));
819     MP_DIGITS(c) = out;
820     MP_ALLOC(c) = p;
821   }
822 
823   MP_USED(c) = osize; /* might not be true, but we'll fix it ... */
824   CLAMP(c);           /* ... right here */
825   MP_SIGN(c) = MP_ZPOS;
826 
827   return MP_OK;
828 }
829 
mp_int_div(mp_int a,mp_int b,mp_int q,mp_int r)830 mp_result mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r)
831 {
832   int cmp, lg;
833   mp_result res = MP_OK;
834   mp_int qout, rout;
835   mp_sign sa = MP_SIGN(a), sb = MP_SIGN(b);
836   DECLARE_TEMP(2);
837 
838   CHECK(a != NULL && b != NULL && q != r);
839 
840   if (CMPZ(b) == 0)
841     return MP_UNDEF;
842   else if ((cmp = s_ucmp(a, b)) < 0) {
843     /* If |a| < |b|, no division is required:
844        q = 0, r = a
845      */
846     if (r && (res = mp_int_copy(a, r)) != MP_OK)
847       return res;
848 
849     if (q)
850       mp_int_zero(q);
851 
852     return MP_OK;
853   }
854   else if (cmp == 0) {
855     /* If |a| = |b|, no division is required:
856        q = 1 or -1, r = 0
857      */
858     if (r)
859       mp_int_zero(r);
860 
861     if (q) {
862       mp_int_zero(q);
863       q->digits[0] = 1;
864 
865       if (sa != sb)
866 	MP_SIGN(q) = MP_NEG;
867     }
868 
869     return MP_OK;
870   }
871 
872   /* When |a| > |b|, real division is required.  We need someplace to store
873      quotient and remainder, but q and r are allowed to be NULL or to overlap
874      with the inputs.
875    */
876   if ((lg = s_isp2(b)) < 0) {
877     if (q && b != q) {
878       if ((res = mp_int_copy(a, q)) != MP_OK)
879 	goto CLEANUP;
880       else
881 	qout = q;
882     }
883     else {
884       qout = LAST_TEMP();
885       SETUP(mp_int_init_copy(LAST_TEMP(), a));
886     }
887 
888     if (r && a != r) {
889       if ((res = mp_int_copy(b, r)) != MP_OK)
890 	goto CLEANUP;
891       else
892 	rout = r;
893     }
894     else {
895       rout = LAST_TEMP();
896       SETUP(mp_int_init_copy(LAST_TEMP(), b));
897     }
898 
899     if ((res = s_udiv_knuth(qout, rout)) != MP_OK) goto CLEANUP;
900   }
901   else {
902     if (q && (res = mp_int_copy(a, q)) != MP_OK) goto CLEANUP;
903     if (r && (res = mp_int_copy(a, r)) != MP_OK) goto CLEANUP;
904 
905     if (q) s_qdiv(q, (mp_size) lg); qout = q;
906     if (r) s_qmod(r, (mp_size) lg); rout = r;
907   }
908 
909   /* Recompute signs for output */
910   if (rout) {
911     MP_SIGN(rout) = sa;
912     if (CMPZ(rout) == 0)
913       MP_SIGN(rout) = MP_ZPOS;
914   }
915   if (qout) {
916     MP_SIGN(qout) = (sa == sb) ? MP_ZPOS : MP_NEG;
917     if (CMPZ(qout) == 0)
918       MP_SIGN(qout) = MP_ZPOS;
919   }
920 
921   if (q && (res = mp_int_copy(qout, q)) != MP_OK) goto CLEANUP;
922   if (r && (res = mp_int_copy(rout, r)) != MP_OK) goto CLEANUP;
923 
924   CLEANUP_TEMP();
925   return res;
926 }
927 
mp_int_mod(mp_int a,mp_int m,mp_int c)928 mp_result mp_int_mod(mp_int a, mp_int m, mp_int c)
929 {
930   mp_result res;
931   mpz_t     tmp;
932   mp_int    out;
933 
934   if (m == c) {
935     mp_int_init(&tmp);
936     out = &tmp;
937   }
938   else {
939     out = c;
940   }
941 
942   if ((res = mp_int_div(a, m, NULL, out)) != MP_OK)
943     goto CLEANUP;
944 
945   if (CMPZ(out) < 0)
946     res = mp_int_add(out, m, c);
947   else
948     res = mp_int_copy(out, c);
949 
950  CLEANUP:
951   if (out != c)
952     mp_int_clear(&tmp);
953 
954   return res;
955 }
956 
mp_int_div_value(mp_int a,mp_small value,mp_int q,mp_small * r)957 mp_result mp_int_div_value(mp_int a, mp_small value, mp_int q, mp_small *r)
958 {
959   mpz_t     vtmp, rtmp;
960   mp_digit  vbuf[MP_VALUE_DIGITS(value)];
961   mp_result res;
962 
963   mp_int_init(&rtmp);
964   s_fake(&vtmp, value, vbuf);
965 
966   if ((res = mp_int_div(a, &vtmp, q, &rtmp)) != MP_OK)
967     goto CLEANUP;
968 
969   if (r)
970     (void) mp_int_to_int(&rtmp, r); /* can't fail */
971 
972  CLEANUP:
973   mp_int_clear(&rtmp);
974   return res;
975 }
976 
mp_int_div_pow2(mp_int a,mp_small p2,mp_int q,mp_int r)977 mp_result mp_int_div_pow2(mp_int a, mp_small p2, mp_int q, mp_int r)
978 {
979   mp_result res = MP_OK;
980 
981   CHECK(a != NULL && p2 >= 0 && q != r);
982 
983   if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK)
984     s_qdiv(q, (mp_size) p2);
985 
986   if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK)
987     s_qmod(r, (mp_size) p2);
988 
989   return res;
990 }
991 
mp_int_expt(mp_int a,mp_small b,mp_int c)992 mp_result mp_int_expt(mp_int a, mp_small b, mp_int c)
993 {
994   mpz_t t;
995   mp_result res;
996   unsigned int v = labs(b);
997 
998   CHECK(c != NULL);
999   if (b < 0)
1000     return MP_RANGE;
1001 
1002   if ((res = mp_int_init_copy(&t, a)) != MP_OK)
1003     return res;
1004 
1005   (void) mp_int_set_value(c, 1);
1006   while (v != 0) {
1007     if (v & 1) {
1008       if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1009 	goto CLEANUP;
1010     }
1011 
1012     v >>= 1;
1013     if (v == 0) break;
1014 
1015     if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1016       goto CLEANUP;
1017   }
1018 
1019  CLEANUP:
1020   mp_int_clear(&t);
1021   return res;
1022 }
1023 
mp_int_expt_value(mp_small a,mp_small b,mp_int c)1024 mp_result mp_int_expt_value(mp_small a, mp_small b, mp_int c)
1025 {
1026   mpz_t     t;
1027   mp_result res;
1028   unsigned int v = labs(b);
1029 
1030   CHECK(c != NULL);
1031   if (b < 0)
1032     return MP_RANGE;
1033 
1034   if ((res = mp_int_init_value(&t, a)) != MP_OK)
1035     return res;
1036 
1037   (void) mp_int_set_value(c, 1);
1038   while (v != 0) {
1039     if (v & 1) {
1040       if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1041 	goto CLEANUP;
1042     }
1043 
1044     v >>= 1;
1045     if (v == 0) break;
1046 
1047     if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1048       goto CLEANUP;
1049   }
1050 
1051  CLEANUP:
1052   mp_int_clear(&t);
1053   return res;
1054 }
1055 
mp_int_expt_full(mp_int a,mp_int b,mp_int c)1056 mp_result mp_int_expt_full(mp_int a, mp_int b, mp_int c)
1057 {
1058   mpz_t t;
1059   mp_result res;
1060   unsigned ix, jx;
1061 
1062   CHECK(a != NULL && b != NULL && c != NULL);
1063   if (MP_SIGN(b) == MP_NEG)
1064     return MP_RANGE;
1065 
1066   if ((res = mp_int_init_copy(&t, a)) != MP_OK)
1067     return res;
1068 
1069   (void) mp_int_set_value(c, 1);
1070   for (ix = 0; ix < MP_USED(b); ++ix) {
1071     mp_digit d = b->digits[ix];
1072 
1073     for (jx = 0; jx < MP_DIGIT_BIT; ++jx) {
1074       if (d & 1) {
1075 	if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1076 	  goto CLEANUP;
1077       }
1078 
1079       d >>= 1;
1080       if (d == 0 && ix + 1 == MP_USED(b))
1081 	break;
1082       if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1083 	goto CLEANUP;
1084     }
1085   }
1086 
1087  CLEANUP:
1088   mp_int_clear(&t);
1089   return res;
1090 }
1091 
mp_int_compare(mp_int a,mp_int b)1092 int       mp_int_compare(mp_int a, mp_int b)
1093 {
1094   mp_sign sa;
1095 
1096   CHECK(a != NULL && b != NULL);
1097 
1098   sa = MP_SIGN(a);
1099   if (sa == MP_SIGN(b)) {
1100     int cmp = s_ucmp(a, b);
1101 
1102     /* If they're both zero or positive, the normal comparison applies; if both
1103        negative, the sense is reversed. */
1104     if (sa == MP_ZPOS)
1105       return cmp;
1106     else
1107       return -cmp;
1108 
1109   }
1110   else {
1111     if (sa == MP_ZPOS)
1112       return 1;
1113     else
1114       return -1;
1115   }
1116 }
1117 
mp_int_compare_unsigned(mp_int a,mp_int b)1118 int       mp_int_compare_unsigned(mp_int a, mp_int b)
1119 {
1120   NRCHECK(a != NULL && b != NULL);
1121 
1122   return s_ucmp(a, b);
1123 }
1124 
mp_int_compare_zero(mp_int z)1125 int       mp_int_compare_zero(mp_int z)
1126 {
1127   NRCHECK(z != NULL);
1128 
1129   if (MP_USED(z) == 1 && z->digits[0] == 0)
1130     return 0;
1131   else if (MP_SIGN(z) == MP_ZPOS)
1132     return 1;
1133   else
1134     return -1;
1135 }
1136 
mp_int_compare_value(mp_int z,mp_small value)1137 int       mp_int_compare_value(mp_int z, mp_small value)
1138 {
1139   mp_sign vsign = (value < 0) ? MP_NEG : MP_ZPOS;
1140   int cmp;
1141 
1142   CHECK(z != NULL);
1143 
1144   if (vsign == MP_SIGN(z)) {
1145     cmp = s_vcmp(z, value);
1146 
1147     return (vsign == MP_ZPOS) ? cmp : -cmp;
1148   }
1149   else {
1150     return (value < 0) ? 1 : -1;
1151   }
1152 }
1153 
mp_int_compare_uvalue(mp_int z,mp_usmall uv)1154 int       mp_int_compare_uvalue(mp_int z, mp_usmall uv)
1155 {
1156   CHECK(z != NULL);
1157 
1158   if (MP_SIGN(z) == MP_NEG)
1159     return -1;
1160   else
1161     return s_uvcmp(z, uv);
1162 }
1163 
mp_int_exptmod(mp_int a,mp_int b,mp_int m,mp_int c)1164 mp_result mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c)
1165 {
1166   mp_result res;
1167   mp_size um;
1168   mp_int s;
1169   DECLARE_TEMP(3);
1170 
1171   CHECK(a != NULL && b != NULL && c != NULL && m != NULL);
1172 
1173   /* Zero moduli and negative exponents are not considered. */
1174   if (CMPZ(m) == 0)
1175     return MP_UNDEF;
1176   if (CMPZ(b) < 0)
1177     return MP_RANGE;
1178 
1179   um = MP_USED(m);
1180   SETUP(mp_int_init_size(TEMP(0), 2 * um));
1181   SETUP(mp_int_init_size(TEMP(1), 2 * um));
1182 
1183   if (c == b || c == m) {
1184     SETUP(mp_int_init_size(TEMP(2), 2 * um));
1185     s = TEMP(2);
1186   }
1187   else {
1188     s = c;
1189   }
1190 
1191   if ((res = mp_int_mod(a, m, TEMP(0))) != MP_OK) goto CLEANUP;
1192 
1193   if ((res = s_brmu(TEMP(1), m)) != MP_OK) goto CLEANUP;
1194 
1195   if ((res = s_embar(TEMP(0), b, m, TEMP(1), s)) != MP_OK)
1196     goto CLEANUP;
1197 
1198   res = mp_int_copy(s, c);
1199 
1200   CLEANUP_TEMP();
1201   return res;
1202 }
1203 
mp_int_exptmod_evalue(mp_int a,mp_small value,mp_int m,mp_int c)1204 mp_result mp_int_exptmod_evalue(mp_int a, mp_small value, mp_int m, mp_int c)
1205 {
1206   mpz_t vtmp;
1207   mp_digit vbuf[MP_VALUE_DIGITS(value)];
1208 
1209   s_fake(&vtmp, value, vbuf);
1210 
1211   return mp_int_exptmod(a, &vtmp, m, c);
1212 }
1213 
mp_int_exptmod_bvalue(mp_small value,mp_int b,mp_int m,mp_int c)1214 mp_result mp_int_exptmod_bvalue(mp_small value, mp_int b,
1215 				mp_int m, mp_int c)
1216 {
1217   mpz_t vtmp;
1218   mp_digit vbuf[MP_VALUE_DIGITS(value)];
1219 
1220   s_fake(&vtmp, value, vbuf);
1221 
1222   return mp_int_exptmod(&vtmp, b, m, c);
1223 }
1224 
mp_int_exptmod_known(mp_int a,mp_int b,mp_int m,mp_int mu,mp_int c)1225 mp_result mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
1226 {
1227   mp_result res;
1228   mp_size um;
1229   mp_int s;
1230   DECLARE_TEMP(2);
1231 
1232   CHECK(a && b && m && c);
1233 
1234   /* Zero moduli and negative exponents are not considered. */
1235   if (CMPZ(m) == 0)
1236     return MP_UNDEF;
1237   if (CMPZ(b) < 0)
1238     return MP_RANGE;
1239 
1240   um = MP_USED(m);
1241   SETUP(mp_int_init_size(TEMP(0), 2 * um));
1242 
1243   if (c == b || c == m) {
1244     SETUP(mp_int_init_size(TEMP(1), 2 * um));
1245     s = TEMP(1);
1246   }
1247   else {
1248     s = c;
1249   }
1250 
1251   if ((res = mp_int_mod(a, m, TEMP(0))) != MP_OK) goto CLEANUP;
1252 
1253   if ((res = s_embar(TEMP(0), b, m, mu, s)) != MP_OK)
1254     goto CLEANUP;
1255 
1256   res = mp_int_copy(s, c);
1257 
1258   CLEANUP_TEMP();
1259   return res;
1260 }
1261 
mp_int_redux_const(mp_int m,mp_int c)1262 mp_result mp_int_redux_const(mp_int m, mp_int c)
1263 {
1264   CHECK(m != NULL && c != NULL && m != c);
1265 
1266   return s_brmu(c, m);
1267 }
1268 
mp_int_invmod(mp_int a,mp_int m,mp_int c)1269 mp_result mp_int_invmod(mp_int a, mp_int m, mp_int c)
1270 {
1271   mp_result res;
1272   mp_sign sa;
1273   DECLARE_TEMP(2);
1274 
1275   CHECK(a != NULL && m != NULL && c != NULL);
1276 
1277   if (CMPZ(a) == 0 || CMPZ(m) <= 0)
1278     return MP_RANGE;
1279 
1280   sa = MP_SIGN(a); /* need this for the result later */
1281 
1282   for (last__ = 0; last__ < 2; ++last__)
1283     mp_int_init(LAST_TEMP());
1284 
1285   if ((res = mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL)) != MP_OK)
1286     goto CLEANUP;
1287 
1288   if (mp_int_compare_value(TEMP(0), 1) != 0) {
1289     res = MP_UNDEF;
1290     goto CLEANUP;
1291   }
1292 
1293   /* It is first necessary to constrain the value to the proper range */
1294   if ((res = mp_int_mod(TEMP(1), m, TEMP(1))) != MP_OK)
1295     goto CLEANUP;
1296 
1297   /* Now, if 'a' was originally negative, the value we have is actually the
1298      magnitude of the negative representative; to get the positive value we
1299      have to subtract from the modulus.  Otherwise, the value is okay as it
1300      stands.
1301    */
1302   if (sa == MP_NEG)
1303     res = mp_int_sub(m, TEMP(1), c);
1304   else
1305     res = mp_int_copy(TEMP(1), c);
1306 
1307   CLEANUP_TEMP();
1308   return res;
1309 }
1310 
/* Binary GCD algorithm due to Josef Stein, 1961 */
/* Sets c = gcd(a, b); the result is never negative.  Returns MP_UNDEF if a
   and b are both zero; if exactly one is zero, c is the absolute value of
   the other.  The main loop works on private copies u and v of the inputs.
   Returns MP_OK, or an error code from the arithmetic routines on failure. */
mp_result mp_int_gcd(mp_int a, mp_int b, mp_int c)
{
  int ca, cb, k = 0;
  mpz_t u, v, t;
  mp_result res;

  CHECK(a != NULL && b != NULL && c != NULL);

  ca = CMPZ(a);
  cb = CMPZ(b);
  if (ca == 0 && cb == 0)
    return MP_UNDEF;
  else if (ca == 0)
    return mp_int_abs(b, c);
  else if (cb == 0)
    return mp_int_abs(a, c);

  /* The cleanup labels release only what has been initialized: a failure
     copying a jumps to U (clears t); a failure copying b jumps to V (clears
     u, then t). */
  mp_int_init(&t);
  if ((res = mp_int_init_copy(&u, a)) != MP_OK)
    goto U;
  if ((res = mp_int_init_copy(&v, b)) != MP_OK)
    goto V;

  /* gcd(a, b) == gcd(|a|, |b|), so work with magnitudes. */
  MP_SIGN(&u) = MP_ZPOS; MP_SIGN(&v) = MP_ZPOS;

  { /* Divide out common factors of 2 from u and v */
    int div2_u = s_dp2k(&u), div2_v = s_dp2k(&v);

    /* k is the shared power of two, restored into the result at the end. */
    k = MIN(div2_u, div2_v);
    s_qdiv(&u, (mp_size) k);
    s_qdiv(&v, (mp_size) k);
  }

  /* Seed t: -v if u is odd, otherwise u (its factors of 2 are stripped on
     the first loop pass). */
  if (mp_int_is_odd(&u)) {
    if ((res = mp_int_neg(&v, &t)) != MP_OK)
      goto CLEANUP;
  }
  else {
    if ((res = mp_int_copy(&u, &t)) != MP_OK)
      goto CLEANUP;
  }

  for (;;) {
    /* Strip every factor of two from t. */
    s_qdiv(&t, s_dp2k(&t));

    /* The sign of t says which of u, v it replaces. */
    if (CMPZ(&t) > 0) {
      if ((res = mp_int_copy(&t, &u)) != MP_OK)
	goto CLEANUP;
    }
    else {
      if ((res = mp_int_neg(&t, &v)) != MP_OK)
	goto CLEANUP;
    }

    if ((res = mp_int_sub(&u, &v, &t)) != MP_OK)
      goto CLEANUP;

    /* u == v: that common value is the odd part of the gcd. */
    if (CMPZ(&t) == 0)
      break;
  }

  /* Multiply the common power of two back in. */
  if ((res = mp_int_abs(&u, c)) != MP_OK)
    goto CLEANUP;
  if (!s_qmul(c, (mp_size) k))
    res = MP_MEMORY;

 CLEANUP:
  mp_int_clear(&v);
 V: mp_int_clear(&u);
 U: mp_int_clear(&t);

  return res;
}
1385 
1386 /* This is the binary GCD algorithm again, but this time we keep track of the
1387    elementary matrix operations as we go, so we can get values x and y
1388    satisfying c = ax + by.
1389  */
mp_int_egcd(mp_int a,mp_int b,mp_int c,mp_int x,mp_int y)1390 mp_result mp_int_egcd(mp_int a, mp_int b, mp_int c,
1391 		      mp_int x, mp_int y)
1392 {
1393   int k, ca, cb;
1394   mp_result res;
1395   DECLARE_TEMP(8);
1396 
1397   CHECK(a != NULL && b != NULL && c != NULL &&
1398 	(x != NULL || y != NULL));
1399 
1400   ca = CMPZ(a);
1401   cb = CMPZ(b);
1402   if (ca == 0 && cb == 0)
1403     return MP_UNDEF;
1404   else if (ca == 0) {
1405     if ((res = mp_int_abs(b, c)) != MP_OK) return res;
1406     mp_int_zero(x); (void) mp_int_set_value(y, 1); return MP_OK;
1407   }
1408   else if (cb == 0) {
1409     if ((res = mp_int_abs(a, c)) != MP_OK) return res;
1410     (void) mp_int_set_value(x, 1); mp_int_zero(y); return MP_OK;
1411   }
1412 
1413   /* Initialize temporaries:
1414      A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7 */
1415   for (last__ = 0; last__ < 4; ++last__)
1416     mp_int_init(LAST_TEMP());
1417   TEMP(0)->digits[0] = 1;
1418   TEMP(3)->digits[0] = 1;
1419 
1420   SETUP(mp_int_init_copy(TEMP(4), a));
1421   SETUP(mp_int_init_copy(TEMP(5), b));
1422 
1423   /* We will work with absolute values here */
1424   MP_SIGN(TEMP(4)) = MP_ZPOS;
1425   MP_SIGN(TEMP(5)) = MP_ZPOS;
1426 
1427   { /* Divide out common factors of 2 from u and v */
1428     int  div2_u = s_dp2k(TEMP(4)), div2_v = s_dp2k(TEMP(5));
1429 
1430     k = MIN(div2_u, div2_v);
1431     s_qdiv(TEMP(4), k);
1432     s_qdiv(TEMP(5), k);
1433   }
1434 
1435   SETUP(mp_int_init_copy(TEMP(6), TEMP(4)));
1436   SETUP(mp_int_init_copy(TEMP(7), TEMP(5)));
1437 
1438   for (;;) {
1439     while (mp_int_is_even(TEMP(4))) {
1440       s_qdiv(TEMP(4), 1);
1441 
1442       if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1))) {
1443 	if ((res = mp_int_add(TEMP(0), TEMP(7), TEMP(0))) != MP_OK)
1444 	  goto CLEANUP;
1445 	if ((res = mp_int_sub(TEMP(1), TEMP(6), TEMP(1))) != MP_OK)
1446 	  goto CLEANUP;
1447       }
1448 
1449       s_qdiv(TEMP(0), 1);
1450       s_qdiv(TEMP(1), 1);
1451     }
1452 
1453     while (mp_int_is_even(TEMP(5))) {
1454       s_qdiv(TEMP(5), 1);
1455 
1456       if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3))) {
1457 	if ((res = mp_int_add(TEMP(2), TEMP(7), TEMP(2))) != MP_OK)
1458 	  goto CLEANUP;
1459 	if ((res = mp_int_sub(TEMP(3), TEMP(6), TEMP(3))) != MP_OK)
1460 	  goto CLEANUP;
1461       }
1462 
1463       s_qdiv(TEMP(2), 1);
1464       s_qdiv(TEMP(3), 1);
1465     }
1466 
1467     if (mp_int_compare(TEMP(4), TEMP(5)) >= 0) {
1468       if ((res = mp_int_sub(TEMP(4), TEMP(5), TEMP(4))) != MP_OK) goto CLEANUP;
1469       if ((res = mp_int_sub(TEMP(0), TEMP(2), TEMP(0))) != MP_OK) goto CLEANUP;
1470       if ((res = mp_int_sub(TEMP(1), TEMP(3), TEMP(1))) != MP_OK) goto CLEANUP;
1471     }
1472     else {
1473       if ((res = mp_int_sub(TEMP(5), TEMP(4), TEMP(5))) != MP_OK) goto CLEANUP;
1474       if ((res = mp_int_sub(TEMP(2), TEMP(0), TEMP(2))) != MP_OK) goto CLEANUP;
1475       if ((res = mp_int_sub(TEMP(3), TEMP(1), TEMP(3))) != MP_OK) goto CLEANUP;
1476     }
1477 
1478     if (CMPZ(TEMP(4)) == 0) {
1479       if (x && (res = mp_int_copy(TEMP(2), x)) != MP_OK) goto CLEANUP;
1480       if (y && (res = mp_int_copy(TEMP(3), y)) != MP_OK) goto CLEANUP;
1481       if (c) {
1482 	if (!s_qmul(TEMP(5), k)) {
1483 	  res = MP_MEMORY;
1484 	  goto CLEANUP;
1485 	}
1486 
1487 	res = mp_int_copy(TEMP(5), c);
1488       }
1489 
1490       break;
1491     }
1492   }
1493 
1494   CLEANUP_TEMP();
1495   return res;
1496 }
1497 
mp_int_lcm(mp_int a,mp_int b,mp_int c)1498 mp_result mp_int_lcm(mp_int a, mp_int b, mp_int c)
1499 {
1500   mpz_t lcm;
1501   mp_result res;
1502 
1503   CHECK(a != NULL && b != NULL && c != NULL);
1504 
1505   /* Since a * b = gcd(a, b) * lcm(a, b), we can compute
1506      lcm(a, b) = (a / gcd(a, b)) * b.
1507 
1508      This formulation insures everything works even if the input
1509      variables share space.
1510    */
1511   if ((res = mp_int_init(&lcm)) != MP_OK)
1512     return res;
1513   if ((res = mp_int_gcd(a, b, &lcm)) != MP_OK)
1514     goto CLEANUP;
1515   if ((res = mp_int_div(a, &lcm, &lcm, NULL)) != MP_OK)
1516     goto CLEANUP;
1517   if ((res = mp_int_mul(&lcm, b, &lcm)) != MP_OK)
1518     goto CLEANUP;
1519 
1520   res = mp_int_copy(&lcm, c);
1521 
1522   CLEANUP:
1523     mp_int_clear(&lcm);
1524 
1525   return res;
1526 }
1527 
mp_int_divisible_value(mp_int a,mp_small v)1528 int       mp_int_divisible_value(mp_int a, mp_small v)
1529 {
1530   mp_small rem = 0;
1531 
1532   if (mp_int_div_value(a, v, NULL, &rem) != MP_OK)
1533     return 0;
1534 
1535   return rem == 0;
1536 }
1537 
mp_int_is_pow2(mp_int z)1538 int       mp_int_is_pow2(mp_int z)
1539 {
1540   CHECK(z != NULL);
1541 
1542   return s_isp2(z);
1543 }
1544 
/* Implementation of Newton's root finding method, based loosely on a patch
   contributed by Hal Finkel <half@halssoftware.com>
   modified by M. J. Fromberger.
 */
/* Sets c to the integer b-th root of a.  b must be positive (enforced only by
   CHECK).  b == 1 copies a.  A negative a with even b yields MP_UNDEF; with
   odd b the root of |a| is computed and negated.

   Starting from the guess x = |a|, Newton steps move x downward; the loop
   stops at the first x with x^b <= |a|, which should be floor(|a|^(1/b)).
   Temporaries: TEMP(0) = |a|, TEMP(1) = current guess x, TEMP(2) = x^b
   (then x^b - |a|), TEMP(3) = b * x^(b-1), TEMP(4) = next guess.
 */
mp_result mp_int_root(mp_int a, mp_small b, mp_int c)
{
  mp_result res = MP_OK;
  int flips = 0;
  DECLARE_TEMP(5);

  CHECK(a != NULL && c != NULL && b > 0);

  if (b == 1) {
    return mp_int_copy(a, c);
  }
  if (MP_SIGN(a) == MP_NEG) {
    if (b % 2 == 0)
      return MP_UNDEF; /* root does not exist for negative a with even b */
    else
      flips = 1;
  }

  SETUP(mp_int_init_copy(LAST_TEMP(), a));
  SETUP(mp_int_init_copy(LAST_TEMP(), a));
  SETUP(mp_int_init(LAST_TEMP()));
  SETUP(mp_int_init(LAST_TEMP()));
  SETUP(mp_int_init(LAST_TEMP()));

  (void) mp_int_abs(TEMP(0), TEMP(0));
  (void) mp_int_abs(TEMP(1), TEMP(1));

  for (;;) {
    if ((res = mp_int_expt(TEMP(1), b, TEMP(2))) != MP_OK)
      goto CLEANUP;

    /* Done once the guess no longer overshoots: x^b <= |a|. */
    if (mp_int_compare_unsigned(TEMP(2), TEMP(0)) <= 0)
      break;

    /* Newton step: x' = x - (x^b - |a|) / (b * x^(b-1)). */
    if ((res = mp_int_sub(TEMP(2), TEMP(0), TEMP(2))) != MP_OK)
      goto CLEANUP;
    if ((res = mp_int_expt(TEMP(1), b - 1, TEMP(3))) != MP_OK)
      goto CLEANUP;
    if ((res = mp_int_mul_value(TEMP(3), b, TEMP(3))) != MP_OK)
      goto CLEANUP;
    if ((res = mp_int_div(TEMP(2), TEMP(3), TEMP(4), NULL)) != MP_OK)
      goto CLEANUP;
    if ((res = mp_int_sub(TEMP(1), TEMP(4), TEMP(4))) != MP_OK)
      goto CLEANUP;

    /* If the integer step made no progress, force it down by one so the
       loop cannot stall. */
    if (mp_int_compare_unsigned(TEMP(1), TEMP(4)) == 0) {
      if ((res = mp_int_sub_value(TEMP(4), 1, TEMP(4))) != MP_OK)
	goto CLEANUP;
    }
    if ((res = mp_int_copy(TEMP(4), TEMP(1))) != MP_OK)
      goto CLEANUP;
  }

  if ((res = mp_int_copy(TEMP(1), c)) != MP_OK)
    goto CLEANUP;

  /* If the original value of a was negative, flip the output sign. */
  if (flips)
    (void) mp_int_neg(c, c); /* cannot fail */

  CLEANUP_TEMP();
  return res;
}
1612 
mp_int_to_int(mp_int z,mp_small * out)1613 mp_result mp_int_to_int(mp_int z, mp_small *out)
1614 {
1615   mp_usmall uv = 0;
1616   mp_size   uz;
1617   mp_digit *dz;
1618   mp_sign   sz;
1619 
1620   CHECK(z != NULL);
1621 
1622   /* Make sure the value is representable as a small integer */
1623   sz = MP_SIGN(z);
1624   if ((sz == MP_ZPOS && mp_int_compare_value(z, MP_SMALL_MAX) > 0) ||
1625       mp_int_compare_value(z, MP_SMALL_MIN) < 0)
1626     return MP_RANGE;
1627 
1628   uz = MP_USED(z);
1629   dz = MP_DIGITS(z) + uz - 1;
1630 
1631   while (uz > 0) {
1632     uv <<= MP_DIGIT_BIT/2;
1633     uv = (uv << (MP_DIGIT_BIT/2)) | *dz--;
1634     --uz;
1635   }
1636 
1637   if (out)
1638     *out = (mp_small)((sz == MP_NEG) ? -uv : uv);
1639 
1640   return MP_OK;
1641 }
1642 
mp_int_to_uint(mp_int z,mp_usmall * out)1643 mp_result mp_int_to_uint(mp_int z, mp_usmall *out)
1644 {
1645   mp_usmall uv = 0;
1646   mp_size   uz;
1647   mp_digit *dz;
1648   mp_sign   sz;
1649 
1650   CHECK(z != NULL);
1651 
1652   /* Make sure the value is representable as an unsigned small integer */
1653   sz = MP_SIGN(z);
1654   if (sz == MP_NEG || mp_int_compare_uvalue(z, MP_USMALL_MAX) > 0)
1655     return MP_RANGE;
1656 
1657   uz = MP_USED(z);
1658   dz = MP_DIGITS(z) + uz - 1;
1659 
1660   while (uz > 0) {
1661     uv <<= MP_DIGIT_BIT/2;
1662     uv = (uv << (MP_DIGIT_BIT/2)) | *dz--;
1663     --uz;
1664   }
1665 
1666   if (out)
1667     *out = uv;
1668 
1669   return MP_OK;
1670 }
1671 
mp_int_to_string(mp_int z,mp_size radix,char * str,int limit)1672 mp_result mp_int_to_string(mp_int z, mp_size radix,
1673 			   char *str, int limit)
1674 {
1675   mp_result res;
1676   int       cmp = 0;
1677 
1678   CHECK(z != NULL && str != NULL && limit >= 2);
1679 
1680   if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1681     return MP_RANGE;
1682 
1683   if (CMPZ(z) == 0) {
1684     *str++ = s_val2ch(0, 1);
1685   }
1686   else {
1687     mpz_t tmp;
1688     char  *h, *t;
1689 
1690     if ((res = mp_int_init_copy(&tmp, z)) != MP_OK)
1691       return res;
1692 
1693     if (MP_SIGN(z) == MP_NEG) {
1694       *str++ = '-';
1695       --limit;
1696     }
1697     h = str;
1698 
1699     /* Generate digits in reverse order until finished or limit reached */
1700     for (/* */; limit > 0; --limit) {
1701       mp_digit d;
1702 
1703       if ((cmp = CMPZ(&tmp)) == 0)
1704 	break;
1705 
1706       d = s_ddiv(&tmp, (mp_digit)radix);
1707       *str++ = s_val2ch(d, 1);
1708     }
1709     t = str - 1;
1710 
1711     /* Put digits back in correct output order */
1712     while (h < t) {
1713       char tc = *h;
1714       *h++ = *t;
1715       *t-- = tc;
1716     }
1717 
1718     mp_int_clear(&tmp);
1719   }
1720 
1721   *str = '\0';
1722   if (cmp == 0)
1723     return MP_OK;
1724   else
1725     return MP_TRUNC;
1726 }
1727 
mp_int_string_len(mp_int z,mp_size radix)1728 mp_result mp_int_string_len(mp_int z, mp_size radix)
1729 {
1730   int  len;
1731 
1732   CHECK(z != NULL);
1733 
1734   if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1735     return MP_RANGE;
1736 
1737   len = s_outlen(z, radix) + 1; /* for terminator */
1738 
1739   /* Allow for sign marker on negatives */
1740   if (MP_SIGN(z) == MP_NEG)
1741     len += 1;
1742 
1743   return len;
1744 }
1745 
1746 /* Read zero-terminated string into z */
mp_int_read_string(mp_int z,mp_size radix,const char * str)1747 mp_result mp_int_read_string(mp_int z, mp_size radix, const char *str)
1748 {
1749   return mp_int_read_cstring(z, radix, str, NULL);
1750 }
1751 
mp_int_read_cstring(mp_int z,mp_size radix,const char * str,char ** end)1752 mp_result mp_int_read_cstring(mp_int z, mp_size radix, const char *str, char **end)
1753 {
1754   int ch;
1755 
1756   CHECK(z != NULL && str != NULL);
1757 
1758   if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1759     return MP_RANGE;
1760 
1761   /* Skip leading whitespace */
1762   while (isspace((int)*str))
1763     ++str;
1764 
1765   /* Handle leading sign tag (+/-, positive default) */
1766   switch (*str) {
1767   case '-':
1768     MP_SIGN(z) = MP_NEG;
1769     ++str;
1770     break;
1771   case '+':
1772     ++str; /* fallthrough */
1773   default:
1774     MP_SIGN(z) = MP_ZPOS;
1775     break;
1776   }
1777 
1778   /* Skip leading zeroes */
1779   while ((ch = s_ch2val(*str, radix)) == 0)
1780     ++str;
1781 
1782   /* Make sure there is enough space for the value */
1783   if (!s_pad(z, s_inlen(strlen(str), radix)))
1784     return MP_MEMORY;
1785 
1786   MP_USED(z) = 1; z->digits[0] = 0;
1787 
1788   while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0)) {
1789     s_dmul(z, (mp_digit)radix);
1790     s_dadd(z, (mp_digit)ch);
1791     ++str;
1792   }
1793 
1794   CLAMP(z);
1795 
1796   /* Override sign for zero, even if negative specified. */
1797   if (CMPZ(z) == 0)
1798     MP_SIGN(z) = MP_ZPOS;
1799 
1800   if (end != NULL)
1801     *end = (char *)str;
1802 
1803   /* Return a truncation error if the string has unprocessed characters
1804      remaining, so the caller can tell if the whole string was done */
1805   if (*str != '\0')
1806     return MP_TRUNC;
1807   else
1808     return MP_OK;
1809 }
1810 
mp_int_count_bits(mp_int z)1811 mp_result mp_int_count_bits(mp_int z)
1812 {
1813   mp_size  nbits = 0, uz;
1814   mp_digit d;
1815 
1816   CHECK(z != NULL);
1817 
1818   uz = MP_USED(z);
1819   if (uz == 1 && z->digits[0] == 0)
1820     return 1;
1821 
1822   --uz;
1823   nbits = uz * MP_DIGIT_BIT;
1824   d = z->digits[uz];
1825 
1826   while (d != 0) {
1827     d >>= 1;
1828     ++nbits;
1829   }
1830 
1831   return nbits;
1832 }
1833 
mp_int_to_binary(mp_int z,unsigned char * buf,int limit)1834 mp_result mp_int_to_binary(mp_int z, unsigned char *buf, int limit)
1835 {
1836   static const int PAD_FOR_2C = 1;
1837 
1838   mp_result res;
1839   int limpos = limit;
1840 
1841   CHECK(z != NULL && buf != NULL);
1842 
1843   res = s_tobin(z, buf, &limpos, PAD_FOR_2C);
1844 
1845   if (MP_SIGN(z) == MP_NEG)
1846     s_2comp(buf, limpos);
1847 
1848   return res;
1849 }
1850 
mp_int_read_binary(mp_int z,unsigned char * buf,int len)1851 mp_result mp_int_read_binary(mp_int z, unsigned char *buf, int len)
1852 {
1853   mp_size need, i;
1854   unsigned char *tmp;
1855   mp_digit *dz;
1856 
1857   CHECK(z != NULL && buf != NULL && len > 0);
1858 
1859   /* Figure out how many digits are needed to represent this value */
1860   need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
1861   if (!s_pad(z, need))
1862     return MP_MEMORY;
1863 
1864   mp_int_zero(z);
1865 
1866   /* If the high-order bit is set, take the 2's complement before reading the
1867      value (it will be restored afterward) */
1868   if (buf[0] >> (CHAR_BIT - 1)) {
1869     MP_SIGN(z) = MP_NEG;
1870     s_2comp(buf, len);
1871   }
1872 
1873   dz = MP_DIGITS(z);
1874   for (tmp = buf, i = len; i > 0; --i, ++tmp) {
1875     s_qmul(z, (mp_size) CHAR_BIT);
1876     *dz |= *tmp;
1877   }
1878 
1879   /* Restore 2's complement if we took it before */
1880   if (MP_SIGN(z) == MP_NEG)
1881     s_2comp(buf, len);
1882 
1883   return MP_OK;
1884 }
1885 
mp_int_binary_len(mp_int z)1886 mp_result mp_int_binary_len(mp_int z)
1887 {
1888   mp_result  res = mp_int_count_bits(z);
1889   int        bytes = mp_int_unsigned_len(z);
1890 
1891   if (res <= 0)
1892     return res;
1893 
1894   bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
1895 
1896   /* If the highest-order bit falls exactly on a byte boundary, we need to pad
1897      with an extra byte so that the sign will be read correctly when reading it
1898      back in. */
1899   if (bytes * CHAR_BIT == res)
1900     ++bytes;
1901 
1902   return bytes;
1903 }
1904 
mp_int_to_unsigned(mp_int z,unsigned char * buf,int limit)1905 mp_result mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit)
1906 {
1907   static const int NO_PADDING = 0;
1908 
1909   CHECK(z != NULL && buf != NULL);
1910 
1911   return s_tobin(z, buf, &limit, NO_PADDING);
1912 }
1913 
mp_int_read_unsigned(mp_int z,unsigned char * buf,int len)1914 mp_result mp_int_read_unsigned(mp_int z, unsigned char *buf, int len)
1915 {
1916   mp_size need, i;
1917   unsigned char *tmp;
1918 
1919   CHECK(z != NULL && buf != NULL && len > 0);
1920 
1921   /* Figure out how many digits are needed to represent this value */
1922   need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
1923   if (!s_pad(z, need))
1924     return MP_MEMORY;
1925 
1926   mp_int_zero(z);
1927 
1928   for (tmp = buf, i = len; i > 0; --i, ++tmp) {
1929     (void) s_qmul(z, CHAR_BIT);
1930     *MP_DIGITS(z) |= *tmp;
1931   }
1932 
1933   return MP_OK;
1934 }
1935 
mp_int_unsigned_len(mp_int z)1936 mp_result mp_int_unsigned_len(mp_int z)
1937 {
1938   mp_result  res = mp_int_count_bits(z);
1939   int        bytes;
1940 
1941   if (res <= 0)
1942     return res;
1943 
1944   bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
1945 
1946   return bytes;
1947 }
1948 
mp_error_string(mp_result res)1949 const char *mp_error_string(mp_result res)
1950 {
1951   int ix;
1952   if (res > 0)
1953     return s_unknown_err;
1954 
1955   res = -res;
1956   for (ix = 0; ix < res && s_error_msg[ix] != NULL; ++ix)
1957     ;
1958 
1959   if (s_error_msg[ix] != NULL)
1960     return s_error_msg[ix];
1961   else
1962     return s_unknown_err;
1963 }
1964 
1965 /*------------------------------------------------------------------------*/
1966 /* Private functions for internal use.  These make assumptions.           */
1967 
s_alloc(mp_size num)1968 STATIC mp_digit *s_alloc(mp_size num)
1969 {
1970   mp_digit *out = malloc(num * sizeof(mp_digit));
1971 
1972   assert(out != NULL); /* for debugging */
1973 #if DEBUG > 1
1974   {
1975     mp_digit v = (mp_digit) 0xdeadbeef;
1976     int      ix;
1977 
1978     for (ix = 0; ix < num; ++ix)
1979       out[ix] = v;
1980   }
1981 #endif
1982 
1983   return out;
1984 }
1985 
s_realloc(mp_digit * old,mp_size osize,mp_size nsize)1986 STATIC mp_digit *s_realloc(mp_digit *old, mp_size osize, mp_size nsize)
1987 {
1988 #if DEBUG > 1
1989   mp_digit *new = s_alloc(nsize);
1990   int       ix;
1991 
1992   for (ix = 0; ix < nsize; ++ix)
1993     new[ix] = (mp_digit) 0xdeadbeef;
1994 
1995   memcpy(new, old, osize * sizeof(mp_digit));
1996 #else
1997   mp_digit *new = realloc(old, nsize * sizeof(mp_digit));
1998 
1999   assert(new != NULL); /* for debugging */
2000 #endif
2001   return new;
2002 }
2003 
s_free(void * ptr)2004 STATIC void s_free(void *ptr)
2005 {
2006   free(ptr);
2007 }
2008 
s_pad(mp_int z,mp_size min)2009 STATIC int      s_pad(mp_int z, mp_size min)
2010 {
2011   if (MP_ALLOC(z) < min) {
2012     mp_size nsize = ROUND_PREC(min);
2013     mp_digit *tmp;
2014 
2015     if ((void *)z->digits == (void *)z) {
2016       if ((tmp = s_alloc(nsize)) == NULL)
2017         return 0;
2018 
2019       COPY(MP_DIGITS(z), tmp, MP_USED(z));
2020     }
2021     else if ((tmp = s_realloc(MP_DIGITS(z), MP_ALLOC(z), nsize)) == NULL)
2022       return 0;
2023 
2024     MP_DIGITS(z) = tmp;
2025     MP_ALLOC(z) = nsize;
2026   }
2027 
2028   return 1;
2029 }
2030 
2031 /* Note: This will not work correctly when value == MP_SMALL_MIN */
s_fake(mp_int z,mp_small value,mp_digit vbuf[])2032 STATIC void      s_fake(mp_int z, mp_small value, mp_digit vbuf[])
2033 {
2034   mp_usmall uv = (mp_usmall) (value < 0) ? -value : value;
2035   s_ufake(z, uv, vbuf);
2036   if (value < 0)
2037     z->sign = MP_NEG;
2038 }
2039 
s_ufake(mp_int z,mp_usmall value,mp_digit vbuf[])2040 STATIC void      s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[])
2041 {
2042   mp_size ndig = (mp_size) s_uvpack(value, vbuf);
2043 
2044   z->used = ndig;
2045   z->alloc = MP_VALUE_DIGITS(value);
2046   z->sign = MP_ZPOS;
2047   z->digits = vbuf;
2048 }
2049 
s_cdig(mp_digit * da,mp_digit * db,mp_size len)2050 STATIC int      s_cdig(mp_digit *da, mp_digit *db, mp_size len)
2051 {
2052   mp_digit *dat = da + len - 1, *dbt = db + len - 1;
2053 
2054   for (/* */; len != 0; --len, --dat, --dbt) {
2055     if (*dat > *dbt)
2056       return 1;
2057     else if (*dat < *dbt)
2058       return -1;
2059   }
2060 
2061   return 0;
2062 }
2063 
s_uvpack(mp_usmall uv,mp_digit t[])2064 STATIC int       s_uvpack(mp_usmall uv, mp_digit t[])
2065 {
2066   int ndig = 0;
2067 
2068   if (uv == 0)
2069     t[ndig++] = 0;
2070   else {
2071     while (uv != 0) {
2072       t[ndig++] = (mp_digit) uv;
2073       uv >>= MP_DIGIT_BIT/2;
2074       uv >>= MP_DIGIT_BIT/2;
2075     }
2076   }
2077 
2078   return ndig;
2079 }
2080 
s_ucmp(mp_int a,mp_int b)2081 STATIC int      s_ucmp(mp_int a, mp_int b)
2082 {
2083   mp_size  ua = MP_USED(a), ub = MP_USED(b);
2084 
2085   if (ua > ub)
2086     return 1;
2087   else if (ub > ua)
2088     return -1;
2089   else
2090     return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua);
2091 }
2092 
s_vcmp(mp_int a,mp_small v)2093 STATIC int      s_vcmp(mp_int a, mp_small v)
2094 {
2095   mp_usmall uv = (v < 0) ? -(mp_usmall) v : (mp_usmall) v;
2096   return s_uvcmp(a, uv);
2097 }
2098 
s_uvcmp(mp_int a,mp_usmall uv)2099 STATIC int      s_uvcmp(mp_int a, mp_usmall uv)
2100 {
2101   mpz_t vtmp;
2102   mp_digit vdig[MP_VALUE_DIGITS(uv)];
2103 
2104   s_ufake(&vtmp, uv, vdig);
2105   return s_ucmp(a, &vtmp);
2106 }
2107 
s_uadd(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2108 STATIC mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc,
2109 		       mp_size size_a, mp_size size_b)
2110 {
2111   mp_size pos;
2112   mp_word w = 0;
2113 
2114   /* Insure that da is the longer of the two to simplify later code */
2115   if (size_b > size_a) {
2116     SWAP(mp_digit *, da, db);
2117     SWAP(mp_size, size_a, size_b);
2118   }
2119 
2120   /* Add corresponding digits until the shorter number runs out */
2121   for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) {
2122     w = w + (mp_word) *da + (mp_word) *db;
2123     *dc = LOWER_HALF(w);
2124     w = UPPER_HALF(w);
2125   }
2126 
2127   /* Propagate carries as far as necessary */
2128   for (/* */; pos < size_a; ++pos, ++da, ++dc) {
2129     w = w + *da;
2130 
2131     *dc = LOWER_HALF(w);
2132     w = UPPER_HALF(w);
2133   }
2134 
2135   /* Return carry out */
2136   return (mp_digit)w;
2137 }
2138 
s_usub(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2139 STATIC void     s_usub(mp_digit *da, mp_digit *db, mp_digit *dc,
2140 		       mp_size size_a, mp_size size_b)
2141 {
2142   mp_size pos;
2143   mp_word w = 0;
2144 
2145   /* We assume that |a| >= |b| so this should definitely hold */
2146   assert(size_a >= size_b);
2147 
2148   /* Subtract corresponding digits and propagate borrow */
2149   for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) {
2150     w = ((mp_word)MP_DIGIT_MAX + 1 +  /* MP_RADIX */
2151 	 (mp_word)*da) - w - (mp_word)*db;
2152 
2153     *dc = LOWER_HALF(w);
2154     w = (UPPER_HALF(w) == 0);
2155   }
2156 
2157   /* Finish the subtraction for remaining upper digits of da */
2158   for (/* */; pos < size_a; ++pos, ++da, ++dc) {
2159     w = ((mp_word)MP_DIGIT_MAX + 1 +  /* MP_RADIX */
2160 	 (mp_word)*da) - w;
2161 
2162     *dc = LOWER_HALF(w);
2163     w = (UPPER_HALF(w) == 0);
2164   }
2165 
2166   /* If there is a borrow out at the end, it violates the precondition */
2167   assert(w == 0);
2168 }
2169 
s_kmul(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2170 STATIC int       s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc,
2171 			mp_size size_a, mp_size size_b)
2172 {
2173   mp_size  bot_size;
2174 
2175   /* Make sure b is the smaller of the two input values */
2176   if (size_b > size_a) {
2177     SWAP(mp_digit *, da, db);
2178     SWAP(mp_size, size_a, size_b);
2179   }
2180 
2181   /* Insure that the bottom is the larger half in an odd-length split; the code
2182      below relies on this being true.
2183    */
2184   bot_size = (size_a + 1) / 2;
2185 
2186   /* If the values are big enough to bother with recursion, use the Karatsuba
2187      algorithm to compute the product; otherwise use the normal multiplication
2188      algorithm
2189    */
2190   if (multiply_threshold &&
2191       size_a >= multiply_threshold &&
2192       size_b > bot_size) {
2193 
2194     mp_digit *t1, *t2, *t3, carry;
2195 
2196     mp_digit *a_top = da + bot_size;
2197     mp_digit *b_top = db + bot_size;
2198 
2199     mp_size  at_size = size_a - bot_size;
2200     mp_size  bt_size = size_b - bot_size;
2201     mp_size  buf_size = 2 * bot_size;
2202 
2203     /* Do a single allocation for all three temporary buffers needed; each
2204        buffer must be big enough to hold the product of two bottom halves, and
2205        one buffer needs space for the completed product; twice the space is
2206        plenty.
2207      */
2208     if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0;
2209     t2 = t1 + buf_size;
2210     t3 = t2 + buf_size;
2211     ZERO(t1, 4 * buf_size);
2212 
2213     /* t1 and t2 are initially used as temporaries to compute the inner product
2214        (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0
2215      */
2216     carry = s_uadd(da, a_top, t1, bot_size, at_size);      /* t1 = a1 + a0 */
2217     t1[bot_size] = carry;
2218 
2219     carry = s_uadd(db, b_top, t2, bot_size, bt_size);      /* t2 = b1 + b0 */
2220     t2[bot_size] = carry;
2221 
2222     (void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1); /* t3 = t1 * t2 */
2223 
2224     /* Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so that
2225        we're left with only the pieces we want:  t3 = a1b0 + a0b1
2226      */
2227     ZERO(t1, buf_size);
2228     ZERO(t2, buf_size);
2229     (void) s_kmul(da, db, t1, bot_size, bot_size);     /* t1 = a0 * b0 */
2230     (void) s_kmul(a_top, b_top, t2, at_size, bt_size); /* t2 = a1 * b1 */
2231 
2232     /* Subtract out t1 and t2 to get the inner product */
2233     s_usub(t3, t1, t3, buf_size + 2, buf_size);
2234     s_usub(t3, t2, t3, buf_size + 2, buf_size);
2235 
2236     /* Assemble the output value */
2237     COPY(t1, dc, buf_size);
2238     carry = s_uadd(t3, dc + bot_size, dc + bot_size,
2239 		   buf_size + 1, buf_size);
2240     assert(carry == 0);
2241 
2242     carry = s_uadd(t2, dc + 2*bot_size, dc + 2*bot_size,
2243 		   buf_size, buf_size);
2244     assert(carry == 0);
2245 
2246     s_free(t1); /* note t2 and t3 are just internal pointers to t1 */
2247   }
2248   else {
2249     s_umul(da, db, dc, size_a, size_b);
2250   }
2251 
2252   return 1;
2253 }
2254 
s_umul(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2255 STATIC void     s_umul(mp_digit *da, mp_digit *db, mp_digit *dc,
2256 		       mp_size size_a, mp_size size_b)
2257 {
2258   mp_size a, b;
2259   mp_word w;
2260 
2261   for (a = 0; a < size_a; ++a, ++dc, ++da) {
2262     mp_digit *dct = dc;
2263     mp_digit *dbt = db;
2264 
2265     if (*da == 0)
2266       continue;
2267 
2268     w = 0;
2269     for (b = 0; b < size_b; ++b, ++dbt, ++dct) {
2270       w = (mp_word)*da * (mp_word)*dbt + w + (mp_word)*dct;
2271 
2272       *dct = LOWER_HALF(w);
2273       w = UPPER_HALF(w);
2274     }
2275 
2276     *dct = (mp_digit)w;
2277   }
2278 }
2279 
s_ksqr(mp_digit * da,mp_digit * dc,mp_size size_a)2280 STATIC int       s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a)
2281 {
2282   if (multiply_threshold && size_a > multiply_threshold) {
2283     mp_size  bot_size = (size_a + 1) / 2;
2284     mp_digit *a_top = da + bot_size;
2285     mp_digit *t1, *t2, *t3, carry;
2286     mp_size  at_size = size_a - bot_size;
2287     mp_size  buf_size = 2 * bot_size;
2288 
2289     if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0;
2290     t2 = t1 + buf_size;
2291     t3 = t2 + buf_size;
2292     ZERO(t1, 4 * buf_size);
2293 
2294     (void) s_ksqr(da, t1, bot_size);    /* t1 = a0 ^ 2 */
2295     (void) s_ksqr(a_top, t2, at_size);  /* t2 = a1 ^ 2 */
2296 
2297     (void) s_kmul(da, a_top, t3, bot_size, at_size);  /* t3 = a0 * a1 */
2298 
2299     /* Quick multiply t3 by 2, shifting left (can't overflow) */
2300     {
2301       int i, top = bot_size + at_size;
2302       mp_word w, save = 0;
2303 
2304       for (i = 0; i < top; ++i) {
2305 	w = t3[i];
2306 	w = (w << 1) | save;
2307 	t3[i] = LOWER_HALF(w);
2308 	save = UPPER_HALF(w);
2309       }
2310       t3[i] = LOWER_HALF(save);
2311     }
2312 
2313     /* Assemble the output value */
2314     COPY(t1, dc, 2 * bot_size);
2315     carry = s_uadd(t3, dc + bot_size, dc + bot_size,
2316 		   buf_size + 1, buf_size);
2317     assert(carry == 0);
2318 
2319     carry = s_uadd(t2, dc + 2*bot_size, dc + 2*bot_size,
2320 		   buf_size, buf_size);
2321     assert(carry == 0);
2322 
2323     s_free(t1); /* note that t2 and t2 are internal pointers only */
2324 
2325   }
2326   else {
2327     s_usqr(da, dc, size_a);
2328   }
2329 
2330   return 1;
2331 }
2332 
s_usqr(mp_digit * da,mp_digit * dc,mp_size size_a)2333 STATIC void      s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a)
2334 {
2335   mp_size i, j;
2336   mp_word w;
2337 
2338   for (i = 0; i < size_a; ++i, dc += 2, ++da) {
2339     mp_digit  *dct = dc, *dat = da;
2340 
2341     if (*da == 0)
2342       continue;
2343 
2344     /* Take care of the first digit, no rollover */
2345     w = (mp_word)*dat * (mp_word)*dat + (mp_word)*dct;
2346     *dct = LOWER_HALF(w);
2347     w = UPPER_HALF(w);
2348     ++dat; ++dct;
2349 
2350     for (j = i + 1; j < size_a; ++j, ++dat, ++dct) {
2351       mp_word  t = (mp_word)*da * (mp_word)*dat;
2352       mp_word  u = w + (mp_word)*dct, ov = 0;
2353 
2354       /* Check if doubling t will overflow a word */
2355       if (HIGH_BIT_SET(t))
2356 	ov = 1;
2357 
2358       w = t + t;
2359 
2360       /* Check if adding u to w will overflow a word */
2361       if (ADD_WILL_OVERFLOW(w, u))
2362 	ov = 1;
2363 
2364       w += u;
2365 
2366       *dct = LOWER_HALF(w);
2367       w = UPPER_HALF(w);
2368       if (ov) {
2369 	w += MP_DIGIT_MAX; /* MP_RADIX */
2370 	++w;
2371       }
2372     }
2373 
2374     w = w + *dct;
2375     *dct = (mp_digit)w;
2376     while ((w = UPPER_HALF(w)) != 0) {
2377       ++dct; w = w + *dct;
2378       *dct = LOWER_HALF(w);
2379     }
2380 
2381     assert(w == 0);
2382   }
2383 }
2384 
s_dadd(mp_int a,mp_digit b)2385 STATIC void      s_dadd(mp_int a, mp_digit b)
2386 {
2387   mp_word w = 0;
2388   mp_digit *da = MP_DIGITS(a);
2389   mp_size ua = MP_USED(a);
2390 
2391   w = (mp_word)*da + b;
2392   *da++ = LOWER_HALF(w);
2393   w = UPPER_HALF(w);
2394 
2395   for (ua -= 1; ua > 0; --ua, ++da) {
2396     w = (mp_word)*da + w;
2397 
2398     *da = LOWER_HALF(w);
2399     w = UPPER_HALF(w);
2400   }
2401 
2402   if (w) {
2403     *da = (mp_digit)w;
2404     MP_USED(a) += 1;
2405   }
2406 }
2407 
s_dmul(mp_int a,mp_digit b)2408 STATIC void      s_dmul(mp_int a, mp_digit b)
2409 {
2410   mp_word w = 0;
2411   mp_digit *da = MP_DIGITS(a);
2412   mp_size ua = MP_USED(a);
2413 
2414   while (ua > 0) {
2415     w = (mp_word)*da * b + w;
2416     *da++ = LOWER_HALF(w);
2417     w = UPPER_HALF(w);
2418     --ua;
2419   }
2420 
2421   if (w) {
2422     *da = (mp_digit)w;
2423     MP_USED(a) += 1;
2424   }
2425 }
2426 
s_dbmul(mp_digit * da,mp_digit b,mp_digit * dc,mp_size size_a)2427 STATIC void      s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a)
2428 {
2429   mp_word  w = 0;
2430 
2431   while (size_a > 0) {
2432     w = (mp_word)*da++ * (mp_word)b + w;
2433 
2434     *dc++ = LOWER_HALF(w);
2435     w = UPPER_HALF(w);
2436     --size_a;
2437   }
2438 
2439   if (w)
2440     *dc = LOWER_HALF(w);
2441 }
2442 
s_ddiv(mp_int a,mp_digit b)2443 STATIC mp_digit  s_ddiv(mp_int a, mp_digit b)
2444 {
2445   mp_word w = 0, qdigit;
2446   mp_size ua = MP_USED(a);
2447   mp_digit *da = MP_DIGITS(a) + ua - 1;
2448 
2449   for (/* */; ua > 0; --ua, --da) {
2450     w = (w << MP_DIGIT_BIT) | *da;
2451 
2452     if (w >= b) {
2453       qdigit = w / b;
2454       w = w % b;
2455     }
2456     else {
2457       qdigit = 0;
2458     }
2459 
2460     *da = (mp_digit)qdigit;
2461   }
2462 
2463   CLAMP(a);
2464   return (mp_digit)w;
2465 }
2466 
s_qdiv(mp_int z,mp_size p2)2467 STATIC void     s_qdiv(mp_int z, mp_size p2)
2468 {
2469   mp_size ndig = p2 / MP_DIGIT_BIT, nbits = p2 % MP_DIGIT_BIT;
2470   mp_size uz = MP_USED(z);
2471 
2472   if (ndig) {
2473     mp_size  mark;
2474     mp_digit *to, *from;
2475 
2476     if (ndig >= uz) {
2477       mp_int_zero(z);
2478       return;
2479     }
2480 
2481     to = MP_DIGITS(z); from = to + ndig;
2482 
2483     for (mark = ndig; mark < uz; ++mark)
2484       *to++ = *from++;
2485 
2486     MP_USED(z) = uz - ndig;
2487   }
2488 
2489   if (nbits) {
2490     mp_digit d = 0, *dz, save;
2491     mp_size  up = MP_DIGIT_BIT - nbits;
2492 
2493     uz = MP_USED(z);
2494     dz = MP_DIGITS(z) + uz - 1;
2495 
2496     for (/* */; uz > 0; --uz, --dz) {
2497       save = *dz;
2498 
2499       *dz = (*dz >> nbits) | (d << up);
2500       d = save;
2501     }
2502 
2503     CLAMP(z);
2504   }
2505 
2506   if (MP_USED(z) == 1 && z->digits[0] == 0)
2507     MP_SIGN(z) = MP_ZPOS;
2508 }
2509 
s_qmod(mp_int z,mp_size p2)2510 STATIC void     s_qmod(mp_int z, mp_size p2)
2511 {
2512   mp_size start = p2 / MP_DIGIT_BIT + 1, rest = p2 % MP_DIGIT_BIT;
2513   mp_size uz = MP_USED(z);
2514   mp_digit mask = (1u << rest) - 1;
2515 
2516   if (start <= uz) {
2517     MP_USED(z) = start;
2518     z->digits[start - 1] &= mask;
2519     CLAMP(z);
2520   }
2521 }
2522 
s_qmul(mp_int z,mp_size p2)2523 STATIC int      s_qmul(mp_int z, mp_size p2)
2524 {
2525   mp_size   uz, need, rest, extra, i;
2526   mp_digit *from, *to, d;
2527 
2528   if (p2 == 0)
2529     return 1;
2530 
2531   uz = MP_USED(z);
2532   need = p2 / MP_DIGIT_BIT; rest = p2 % MP_DIGIT_BIT;
2533 
2534   /* Figure out if we need an extra digit at the top end; this occurs if the
2535      topmost `rest' bits of the high-order digit of z are not zero, meaning
2536      they will be shifted off the end if not preserved */
2537   extra = 0;
2538   if (rest != 0) {
2539     mp_digit *dz = MP_DIGITS(z) + uz - 1;
2540 
2541     if ((*dz >> (MP_DIGIT_BIT - rest)) != 0)
2542       extra = 1;
2543   }
2544 
2545   if (!s_pad(z, uz + need + extra))
2546     return 0;
2547 
2548   /* If we need to shift by whole digits, do that in one pass, then
2549      to back and shift by partial digits.
2550    */
2551   if (need > 0) {
2552     from = MP_DIGITS(z) + uz - 1;
2553     to = from + need;
2554 
2555     for (i = 0; i < uz; ++i)
2556       *to-- = *from--;
2557 
2558     ZERO(MP_DIGITS(z), need);
2559     uz += need;
2560   }
2561 
2562   if (rest) {
2563     d = 0;
2564     for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from) {
2565       mp_digit save = *from;
2566 
2567       *from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest));
2568       d = save;
2569     }
2570 
2571     d >>= (MP_DIGIT_BIT - rest);
2572     if (d != 0) {
2573       *from = d;
2574       uz += extra;
2575     }
2576   }
2577 
2578   MP_USED(z) = uz;
2579   CLAMP(z);
2580 
2581   return 1;
2582 }
2583 
2584 /* Compute z = 2^p2 - |z|; requires that 2^p2 >= |z|
2585    The sign of the result is always zero/positive.
2586  */
s_qsub(mp_int z,mp_size p2)2587 STATIC int       s_qsub(mp_int z, mp_size p2)
2588 {
2589   mp_digit hi = (1 << (p2 % MP_DIGIT_BIT)), *zp;
2590   mp_size  tdig = (p2 / MP_DIGIT_BIT), pos;
2591   mp_word  w = 0;
2592 
2593   if (!s_pad(z, tdig + 1))
2594     return 0;
2595 
2596   for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp) {
2597     w = ((mp_word) MP_DIGIT_MAX + 1) - w - (mp_word)*zp;
2598 
2599     *zp = LOWER_HALF(w);
2600     w = UPPER_HALF(w) ? 0 : 1;
2601   }
2602 
2603   w = ((mp_word) MP_DIGIT_MAX + 1 + hi) - w - (mp_word)*zp;
2604   *zp = LOWER_HALF(w);
2605 
2606   assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */
2607 
2608   MP_SIGN(z) = MP_ZPOS;
2609   CLAMP(z);
2610 
2611   return 1;
2612 }
2613 
s_dp2k(mp_int z)2614 STATIC int      s_dp2k(mp_int z)
2615 {
2616   int       k = 0;
2617   mp_digit *dp = MP_DIGITS(z), d;
2618 
2619   if (MP_USED(z) == 1 && *dp == 0)
2620     return 1;
2621 
2622   while (*dp == 0) {
2623     k += MP_DIGIT_BIT;
2624     ++dp;
2625   }
2626 
2627   d = *dp;
2628   while ((d & 1) == 0) {
2629     d >>= 1;
2630     ++k;
2631   }
2632 
2633   return k;
2634 }
2635 
s_isp2(mp_int z)2636 STATIC int       s_isp2(mp_int z)
2637 {
2638   mp_size uz = MP_USED(z), k = 0;
2639   mp_digit *dz = MP_DIGITS(z), d;
2640 
2641   while (uz > 1) {
2642     if (*dz++ != 0)
2643       return -1;
2644     k += MP_DIGIT_BIT;
2645     --uz;
2646   }
2647 
2648   d = *dz;
2649   while (d > 1) {
2650     if (d & 1)
2651       return -1;
2652     ++k; d >>= 1;
2653   }
2654 
2655   return (int) k;
2656 }
2657 
s_2expt(mp_int z,mp_small k)2658 STATIC int       s_2expt(mp_int z, mp_small k)
2659 {
2660   mp_size  ndig, rest;
2661   mp_digit *dz;
2662 
2663   ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT;
2664   rest = k % MP_DIGIT_BIT;
2665 
2666   if (!s_pad(z, ndig))
2667     return 0;
2668 
2669   dz = MP_DIGITS(z);
2670   ZERO(dz, ndig);
2671   *(dz + ndig - 1) = (1 << rest);
2672   MP_USED(z) = ndig;
2673 
2674   return 1;
2675 }
2676 
s_norm(mp_int a,mp_int b)2677 STATIC int      s_norm(mp_int a, mp_int b)
2678 {
2679   mp_digit d = b->digits[MP_USED(b) - 1];
2680   int k = 0;
2681 
2682   while (d < (1u << (mp_digit)(MP_DIGIT_BIT - 1))) { /* d < (MP_RADIX / 2) */
2683     d <<= 1;
2684     ++k;
2685   }
2686 
2687   /* These multiplications can't fail */
2688   if (k != 0) {
2689     (void) s_qmul(a, (mp_size) k);
2690     (void) s_qmul(b, (mp_size) k);
2691   }
2692 
2693   return k;
2694 }
2695 
s_brmu(mp_int z,mp_int m)2696 STATIC mp_result s_brmu(mp_int z, mp_int m)
2697 {
2698   mp_size um = MP_USED(m) * 2;
2699 
2700   if (!s_pad(z, um))
2701     return MP_MEMORY;
2702 
2703   s_2expt(z, MP_DIGIT_BIT * um);
2704   return mp_int_div(z, m, z, NULL);
2705 }
2706 
s_reduce(mp_int x,mp_int m,mp_int mu,mp_int q1,mp_int q2)2707 STATIC int       s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2)
2708 {
2709   mp_size   um = MP_USED(m), umb_p1, umb_m1;
2710 
2711   umb_p1 = (um + 1) * MP_DIGIT_BIT;
2712   umb_m1 = (um - 1) * MP_DIGIT_BIT;
2713 
2714   if (mp_int_copy(x, q1) != MP_OK)
2715     return 0;
2716 
2717   /* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */
2718   s_qdiv(q1, umb_m1);
2719   UMUL(q1, mu, q2);
2720   s_qdiv(q2, umb_p1);
2721 
2722   /* Set x = x mod b^(k+1) */
2723   s_qmod(x, umb_p1);
2724 
2725   /* Now, q is a guess for the quotient a / m.
2726      Compute x - q * m mod b^(k+1), replacing x.  This may be off
2727      by a factor of 2m, but no more than that.
2728    */
2729   UMUL(q2, m, q1);
2730   s_qmod(q1, umb_p1);
2731   (void) mp_int_sub(x, q1, x); /* can't fail */
2732 
2733   /* The result may be < 0; if it is, add b^(k+1) to pin it in the proper
2734      range. */
2735   if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1))
2736     return 0;
2737 
2738   /* If x > m, we need to back it off until it is in range.  This will be
2739      required at most twice.  */
2740   if (mp_int_compare(x, m) >= 0) {
2741     (void) mp_int_sub(x, m, x);
2742     if (mp_int_compare(x, m) >= 0)
2743       (void) mp_int_sub(x, m, x);
2744   }
2745 
2746   /* At this point, x has been properly reduced. */
2747   return 1;
2748 }
2749 
2750 /* Perform modular exponentiation using Barrett's method, where mu is the
2751    reduction constant for m.  Assumes a < m, b > 0. */
s_embar(mp_int a,mp_int b,mp_int m,mp_int mu,mp_int c)2752 STATIC mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
2753 {
2754   mp_digit  *db, *dbt, umu, d;
2755   mp_result res;
2756   DECLARE_TEMP(3);
2757 
2758   umu = MP_USED(mu); db = MP_DIGITS(b); dbt = db + MP_USED(b) - 1;
2759 
2760   while (last__ < 3) {
2761     SETUP(mp_int_init_size(LAST_TEMP(), 4 * umu));
2762     ZERO(MP_DIGITS(TEMP(last__ - 1)), MP_ALLOC(TEMP(last__ - 1)));
2763   }
2764 
2765   (void) mp_int_set_value(c, 1);
2766 
2767   /* Take care of low-order digits */
2768   while (db < dbt) {
2769     int      i;
2770 
2771     for (d = *db, i = MP_DIGIT_BIT; i > 0; --i, d >>= 1) {
2772       if (d & 1) {
2773 	/* The use of a second temporary avoids allocation */
2774 	UMUL(c, a, TEMP(0));
2775 	if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2776 	  res = MP_MEMORY; goto CLEANUP;
2777 	}
2778 	mp_int_copy(TEMP(0), c);
2779       }
2780 
2781 
2782       USQR(a, TEMP(0));
2783       assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
2784       if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2785 	res = MP_MEMORY; goto CLEANUP;
2786       }
2787       assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
2788       mp_int_copy(TEMP(0), a);
2789     }
2790 
2791     ++db;
2792   }
2793 
2794   /* Take care of highest-order digit */
2795   d = *dbt;
2796   for (;;) {
2797     if (d & 1) {
2798       UMUL(c, a, TEMP(0));
2799       if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2800 	res = MP_MEMORY; goto CLEANUP;
2801       }
2802       mp_int_copy(TEMP(0), c);
2803     }
2804 
2805     d >>= 1;
2806     if (!d) break;
2807 
2808     USQR(a, TEMP(0));
2809     if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2810       res = MP_MEMORY; goto CLEANUP;
2811     }
2812     (void) mp_int_copy(TEMP(0), a);
2813   }
2814 
2815   CLEANUP_TEMP();
2816   return res;
2817 }
2818 
/* Division of nonnegative integers

   This function implements the division algorithm for unsigned
   multi-precision integers. The algorithm is based on Algorithm D from
   Knuth's "The Art of Computer Programming", Vol. 2, 3rd ed. 1998,
   pg 272-273.

   We diverge from Knuth's algorithm in that we do not perform the subtraction
   from the remainder until we have determined that we have the correct
   quotient digit. This makes our algorithm less efficient than Knuth's because
   we might have to perform multiple multiplication and comparison steps before
   the subtraction. The advantage is that it is easy to implement and ensure
   correctness without worrying about underflow from the subtraction.

   inputs: u   an n+m digit integer in base b (b is 2^MP_DIGIT_BIT)
           v   an n   digit integer in base b (b is 2^MP_DIGIT_BIT)
           n >= 1
           m >= 0
  outputs: u / v stored in u
           u % v stored in v

   The signs of u and v are ignored; both are forced to MP_ZPOS on entry.
   Returns MP_OK on success, MP_MEMORY if u or v cannot be grown, or the
   error from mp_int_init_size if a temporary cannot be created.  On an
   error return, u and v may already have been modified (signs forced,
   operands normalized).
   NOTE(review): m is computed as MP_USED(u) - MP_USED(v) in an mp_size.  If
   that type is unsigned, assert(m >= 0) below cannot detect |u| having fewer
   digits than |v|, so callers must guarantee MP_USED(u) >= MP_USED(v).
 */
STATIC mp_result s_udiv_knuth(mp_int u, mp_int v) {
  mpz_t q, r, t;
  mp_result
  res = MP_OK;
  int k,j;
  mp_size m,n;

  /* Force signs to positive */
  MP_SIGN(u) = MP_ZPOS;
  MP_SIGN(v) = MP_ZPOS;

  /* Use simple division algorithm when v is only one digit long */
  if (MP_USED(v) == 1) {
    mp_digit d, rem;
    d   = v->digits[0];
    rem = s_ddiv(u, d);
    mp_int_set_value(v, rem);
    return MP_OK;
  }

  /* Algorithm D

     The n and m variables are defined as used by Knuth.
     v is an n digit number with digits v_{n-1}..v_0.
     u is an n+m digit number with digits u_{m+n-1}..u_0.
     We require that n > 1 and m >= 0
   */
  n = MP_USED(v);
  m = MP_USED(u) - n;
  assert(n > 1);
  assert(m >= 0);

  /* D1: Normalize.
     The normalization step provides the necessary condition for Theorem B,
     which states that the quotient estimate for q_j, call it qhat

       qhat = u_{j+n}u_{j+n-1} / v_{n-1}

     is bounded by

      qhat - 2 <= q_j <= qhat.

     That is, qhat is never less than the actual quotient digit q_j,
     and it is never more than two larger than it.

     s_norm shifts u and v left by the same k bits so that the top digit of
     v has its high bit set; the remainder is shifted back by k at the end.
   */
  k = s_norm(u, v);

  /* Extend size of u by one if needed.

     The algorithm begins with a value of u that has one more digit of input.
     The normalization step sets u_{m+n}..u_0 = 2^k * u_{m+n-1}..u_0. If the
     multiplication did not increase the number of digits of u, we need to add
     a leading zero here.
   */
  if (k == 0 || MP_USED(u) != m + n + 1) {
    if (!s_pad(u, m+n+1))
      return MP_MEMORY;
    u->digits[m+n] = 0;
    u->used = m+n+1;
  }

  /* Add a leading 0 to v.

     The multiplication in step D4 multiplies qhat * 0v_{n-1}..v_0.  We need to
     add the leading zero to v here to ensure that the multiplication will
     produce the full n+1 digit result.  (The zero is stored past MP_USED(v);
     the used count of v is not changed.)
   */
  if (!s_pad(v, n+1)) return MP_MEMORY; v->digits[n] = 0;

  /* Initialize temporary variables q and t.
     q allocates space for m+1 digits to store the quotient digits
     t allocates space for n+1 digits to hold the result of q_j*v
   */
  if ((res = mp_int_init_size(&q, m + 1)) != MP_OK) return res;
  if ((res = mp_int_init_size(&t, n + 1)) != MP_OK) goto CLEANUP;

  /* D2: Initialize j */
  j = m;
  /* r is a window of n+1 digits onto u (u_{j+n}..u_j); it owns no storage
     and is never cleared. */
  r.digits = MP_DIGITS(u) + j;  /* The contents of r are shared with u */
  r.used   = n + 1;
  r.sign   = MP_ZPOS;
  r.alloc  = MP_ALLOC(u);
  ZERO(t.digits, t.alloc);

  /* Calculate the m+1 digits of the quotient result */
  for (; j >= 0; j--) {
    /* D3: Calculate q' */
    /* r.digits is aligned to position j of the number u */
    mp_word pfx, qhat;
    pfx   = r.digits[n];
    pfx <<= MP_DIGIT_BIT / 2;   /* move up one digit, in two half steps */
    pfx <<= MP_DIGIT_BIT / 2;
    pfx |= r.digits[n-1]; /* pfx = u_{j+n}u_{j+n-1} */

    qhat = pfx / v->digits[n-1];
    /* qhat may exceed the largest digit value b - 1 (MP_DIGIT_MAX); if so,
       clamp it.  Theorem B guarantees that qhat is at most 2 larger than the
       actual quotient digit, and q_j itself never exceeds b - 1, so the
       clamped value is still an upper bound that the checks below refine. */
    if (qhat > MP_DIGIT_MAX)
      qhat = MP_DIGIT_MAX;

    /* D4,D5,D6: Multiply qhat * v and test for a correct value of q

       We proceed a bit differently from the way described by Knuth. This way
       is simpler but less efficient. Instead of doing the multiply and
       subtract then checking for underflow, we first do the multiply of
       qhat * v and see if it is larger than the current remainder r. If it is
       larger, we decrease qhat by one and try again. We may need to decrease
       qhat one more time before we get a value that is no larger than r.

       This way is less efficient than Knuth's because we do more multiplies,
       but we do not need to worry about underflow this way.
     */
    /* t = qhat * v */
    s_dbmul(MP_DIGITS(v), (mp_digit) qhat, t.digits, n+1); t.used = n + 1;
    CLAMP(&t);

    /* Clamp r for the comparison. Comparisons do not like leading zeros. */
    CLAMP(&r);
    if (s_ucmp(&t, &r) > 0) {   /* would the remainder be negative? */
      qhat -= 1;   /* try a smaller q */
      s_dbmul(MP_DIGITS(v), (mp_digit) qhat, t.digits, n+1);
      t.used = n + 1; CLAMP(&t);
      if (s_ucmp(&t, &r) > 0) { /* would the remainder be negative? */
        assert(qhat > 0);
        qhat -= 1; /* try a smaller q */
        s_dbmul(MP_DIGITS(v), (mp_digit) qhat, t.digits, n+1);
        t.used = n + 1; CLAMP(&t);
      }
      assert(s_ucmp(&t, &r) <=  0 && "The mathematics failed us.");
    }
    /* Unclamp r. The D algorithm expects r = u_{j+n}..u_j to always be n+1
       digits long. */
    r.used = n + 1;

    /* D4: Multiply and subtract

       Note: The multiply was completed above so we only need to subtract here.
       The subtraction writes straight into u through the r window.
     */
    s_usub(r.digits, t.digits, r.digits, r.used, t.used);

    /* D5: Test remainder

       Note: Not needed because we always check that qhat is the correct value
             before performing the subtract.  The value is cast to mp_digit to
             prevent a warning; qhat has been clamped to MP_DIGIT_MAX.
     */
    q.digits[j] = (mp_digit)qhat;

    /* D6: Add back
       Note: Not needed because we always check that qhat is the correct value
             before performing the subtract.
     */

    /* D7: Loop on j -- slide the window down one digit and clear t for the
       next product */
    r.digits--;
    ZERO(t.digits, t.alloc);
  }

  /* Get rid of leading zeros in q */
  q.used = m + 1;
  CLAMP(&q);

  /* Denormalize the remainder */
  CLAMP(u); /* use u here because the r.digits pointer is off-by-one */
  if (k != 0)
    s_qdiv(u, k);

  mp_int_copy(u, v);  /* ok:  0 <= r < v */
  mp_int_copy(&q, u); /* ok:  q <= u     */

  mp_int_clear(&t);
 CLEANUP:
  mp_int_clear(&q);
  return res;
}
3016 
s_outlen(mp_int z,mp_size r)3017 STATIC int       s_outlen(mp_int z, mp_size r)
3018 {
3019   mp_result bits;
3020   double raw;
3021 
3022   assert(r >= MP_MIN_RADIX && r <= MP_MAX_RADIX);
3023 
3024   bits = mp_int_count_bits(z);
3025   raw = (double)bits * s_log2[r];
3026 
3027   return (int)(raw + 0.999999);
3028 }
3029 
s_inlen(int len,mp_size r)3030 STATIC mp_size   s_inlen(int len, mp_size r)
3031 {
3032   double  raw = (double)len / s_log2[r];
3033   mp_size bits = (mp_size)(raw + 0.5);
3034 
3035   return (mp_size)((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT) + 1;
3036 }
3037 
s_ch2val(char c,int r)3038 STATIC int       s_ch2val(char c, int r)
3039 {
3040   int out;
3041 
3042   if (isdigit((unsigned char) c))
3043     out = c - '0';
3044   else if (r > 10 && isalpha((unsigned char) c))
3045     out = toupper(c) - 'A' + 10;
3046   else
3047     return -1;
3048 
3049   return (out >= r) ? -1 : out;
3050 }
3051 
s_val2ch(int v,int caps)3052 STATIC char      s_val2ch(int v, int caps)
3053 {
3054   assert(v >= 0);
3055 
3056   if (v < 10)
3057     return v + '0';
3058   else {
3059     char out = (v - 10) + 'a';
3060 
3061     if (caps)
3062       return toupper(out);
3063     else
3064       return out;
3065   }
3066 }
3067 
s_2comp(unsigned char * buf,int len)3068 STATIC void      s_2comp(unsigned char *buf, int len)
3069 {
3070   int i;
3071   unsigned short s = 1;
3072 
3073   for (i = len - 1; i >= 0; --i) {
3074     unsigned char c = ~buf[i];
3075 
3076     s = c + s;
3077     c = s & UCHAR_MAX;
3078     s >>= CHAR_BIT;
3079 
3080     buf[i] = c;
3081   }
3082 
3083   /* last carry out is ignored */
3084 }
3085 
s_tobin(mp_int z,unsigned char * buf,int * limpos,int pad)3086 STATIC mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad)
3087 {
3088   mp_size uz;
3089   mp_digit *dz;
3090   int pos = 0, limit = *limpos;
3091 
3092   uz = MP_USED(z); dz = MP_DIGITS(z);
3093   while (uz > 0 && pos < limit) {
3094     mp_digit d = *dz++;
3095     int i;
3096 
3097     for (i = sizeof(mp_digit); i > 0 && pos < limit; --i) {
3098       buf[pos++] = (unsigned char)d;
3099       d >>= CHAR_BIT;
3100 
3101       /* Don't write leading zeroes */
3102       if (d == 0 && uz == 1)
3103 	i = 0; /* exit loop without signaling truncation */
3104     }
3105 
3106     /* Detect truncation (loop exited with pos >= limit) */
3107     if (i > 0) break;
3108 
3109     --uz;
3110   }
3111 
3112   if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1))) {
3113     if (pos < limit)
3114       buf[pos++] = 0;
3115     else
3116       uz = 1;
3117   }
3118 
3119   /* Digits are in reverse order, fix that */
3120   REV(unsigned char, buf, pos);
3121 
3122   /* Return the number of bytes actually written */
3123   *limpos = pos;
3124 
3125   return (uz == 0) ? MP_OK : MP_TRUNC;
3126 }
3127 
3128 #if DEBUG
s_print(char * tag,mp_int z)3129 void      s_print(char *tag, mp_int z)
3130 {
3131   int  i;
3132 
3133   fprintf(stderr, "%s: %c ", tag,
3134 	  (MP_SIGN(z) == MP_NEG) ? '-' : '+');
3135 
3136   for (i = MP_USED(z) - 1; i >= 0; --i)
3137     fprintf(stderr, "%0*X", (int)(MP_DIGIT_BIT / 4), z->digits[i]);
3138 
3139   fputc('\n', stderr);
3140 
3141 }
3142 
s_print_buf(char * tag,mp_digit * buf,mp_size num)3143 void      s_print_buf(char *tag, mp_digit *buf, mp_size num)
3144 {
3145   int i;
3146 
3147   fprintf(stderr, "%s: ", tag);
3148 
3149   for (i = num - 1; i >= 0; --i)
3150     fprintf(stderr, "%0*X", (int)(MP_DIGIT_BIT / 4), buf[i]);
3151 
3152   fputc('\n', stderr);
3153 }
3154 #endif
3155 
3156 /* Here there be dragons */
3157