1 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
2  * All rights reserved.
3  *
4  * This package is an SSL implementation written
5  * by Eric Young (eay@cryptsoft.com).
6  * The implementation was written so as to conform with Netscapes SSL.
7  *
8  * This library is free for commercial and non-commercial use as long as
9  * the following conditions are aheared to.  The following conditions
10  * apply to all code found in this distribution, be it the RC4, RSA,
11  * lhash, DES, etc., code; not just the SSL code.  The SSL documentation
12  * included with this distribution is covered by the same copyright terms
13  * except that the holder is Tim Hudson (tjh@cryptsoft.com).
14  *
15  * Copyright remains Eric Young's, and as such any Copyright notices in
16  * the code are not to be removed.
17  * If this package is used in a product, Eric Young should be given attribution
18  * as the author of the parts of the library used.
19  * This can be in the form of a textual message at program startup or
20  * in documentation (online or textual) provided with the package.
21  *
22  * Redistribution and use in source and binary forms, with or without
23  * modification, are permitted provided that the following conditions
24  * are met:
25  * 1. Redistributions of source code must retain the copyright
26  *    notice, this list of conditions and the following disclaimer.
27  * 2. Redistributions in binary form must reproduce the above copyright
28  *    notice, this list of conditions and the following disclaimer in the
29  *    documentation and/or other materials provided with the distribution.
30  * 3. All advertising materials mentioning features or use of this software
31  *    must display the following acknowledgement:
32  *    "This product includes cryptographic software written by
33  *     Eric Young (eay@cryptsoft.com)"
34  *    The word 'cryptographic' can be left out if the rouines from the library
35  *    being used are not cryptographic related :-).
36  * 4. If you include any Windows specific code (or a derivative thereof) from
37  *    the apps directory (application code) you must include an acknowledgement:
38  *    "This product includes software written by Tim Hudson (tjh@cryptsoft.com)"
39  *
40  * THIS SOFTWARE IS PROVIDED BY ERIC YOUNG ``AS IS'' AND
41  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
42  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
43  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
44  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
45  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
46  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
47  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
48  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
49  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
50  * SUCH DAMAGE.
51  *
52  * The licence and distribution terms for any publically available version or
53  * derivative of this code cannot be changed.  i.e. this code cannot simply be
54  * copied and put under another distribution licence
55  * [including the GNU Public Licence.]
56  */
57 /* ====================================================================
58  * Copyright 2002 Sun Microsystems, Inc. ALL RIGHTS RESERVED.
59  *
60  * Portions of the attached software ("Contribution") are developed by
61  * SUN MICROSYSTEMS, INC., and are contributed to the OpenSSL project.
62  *
63  * The Contribution is licensed pursuant to the Eric Young open source
64  * license provided above.
65  *
66  * The binary polynomial arithmetic software is originally written by
67  * Sheueling Chang Shantz and Douglas Stebila of Sun Microsystems
68  * Laboratories. */
69 
70 /* Per C99, various stdint.h and inttypes.h macros (the latter used by bn.h) are
71  * unavailable in C++ unless some macros are defined. C++11 overruled this
72  * decision, but older Android NDKs still require it. */
73 #if !defined(__STDC_CONSTANT_MACROS)
74 #define __STDC_CONSTANT_MACROS
75 #endif
76 #if !defined(__STDC_FORMAT_MACROS)
77 #define __STDC_FORMAT_MACROS
78 #endif
79 
80 #include <assert.h>
81 #include <errno.h>
82 #include <limits.h>
83 #include <stdio.h>
84 #include <string.h>
85 
86 #include <utility>
87 
88 #include <openssl/bn.h>
89 #include <openssl/bytestring.h>
90 #include <openssl/crypto.h>
91 #include <openssl/err.h>
92 #include <openssl/mem.h>
93 
94 #include "../internal.h"
95 #include "../test/file_test.h"
96 #include "../test/test_util.h"
97 
98 
HexToBIGNUM(bssl::UniquePtr<BIGNUM> * out,const char * in)99 static int HexToBIGNUM(bssl::UniquePtr<BIGNUM> *out, const char *in) {
100   BIGNUM *raw = NULL;
101   int ret = BN_hex2bn(&raw, in);
102   out->reset(raw);
103   return ret;
104 }
105 
GetBIGNUM(FileTest * t,const char * attribute)106 static bssl::UniquePtr<BIGNUM> GetBIGNUM(FileTest *t, const char *attribute) {
107   std::string hex;
108   if (!t->GetAttribute(&hex, attribute)) {
109     return nullptr;
110   }
111 
112   bssl::UniquePtr<BIGNUM> ret;
113   if (HexToBIGNUM(&ret, hex.c_str()) != static_cast<int>(hex.size())) {
114     t->PrintLine("Could not decode '%s'.", hex.c_str());
115     return nullptr;
116   }
117   return ret;
118 }
119 
GetInt(FileTest * t,int * out,const char * attribute)120 static bool GetInt(FileTest *t, int *out, const char *attribute) {
121   bssl::UniquePtr<BIGNUM> ret = GetBIGNUM(t, attribute);
122   if (!ret) {
123     return false;
124   }
125 
126   BN_ULONG word = BN_get_word(ret.get());
127   if (word > INT_MAX) {
128     return false;
129   }
130 
131   *out = static_cast<int>(word);
132   return true;
133 }
134 
ExpectBIGNUMsEqual(FileTest * t,const char * operation,const BIGNUM * expected,const BIGNUM * actual)135 static bool ExpectBIGNUMsEqual(FileTest *t, const char *operation,
136                                const BIGNUM *expected, const BIGNUM *actual) {
137   if (BN_cmp(expected, actual) == 0) {
138     return true;
139   }
140 
141   bssl::UniquePtr<char> expected_str(BN_bn2hex(expected));
142   bssl::UniquePtr<char> actual_str(BN_bn2hex(actual));
143   if (!expected_str || !actual_str) {
144     return false;
145   }
146 
147   t->PrintLine("Got %s =", operation);
148   t->PrintLine("\t%s", actual_str.get());
149   t->PrintLine("wanted:");
150   t->PrintLine("\t%s", expected_str.get());
151   return false;
152 }
153 
TestSum(FileTest * t,BN_CTX * ctx)154 static bool TestSum(FileTest *t, BN_CTX *ctx) {
155   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
156   bssl::UniquePtr<BIGNUM> b = GetBIGNUM(t, "B");
157   bssl::UniquePtr<BIGNUM> sum = GetBIGNUM(t, "Sum");
158   if (!a || !b || !sum) {
159     return false;
160   }
161 
162   bssl::UniquePtr<BIGNUM> ret(BN_new());
163   if (!ret ||
164       !BN_add(ret.get(), a.get(), b.get()) ||
165       !ExpectBIGNUMsEqual(t, "A + B", sum.get(), ret.get()) ||
166       !BN_sub(ret.get(), sum.get(), a.get()) ||
167       !ExpectBIGNUMsEqual(t, "Sum - A", b.get(), ret.get()) ||
168       !BN_sub(ret.get(), sum.get(), b.get()) ||
169       !ExpectBIGNUMsEqual(t, "Sum - B", a.get(), ret.get())) {
170     return false;
171   }
172 
173   // Test that the functions work when |r| and |a| point to the same |BIGNUM|,
174   // or when |r| and |b| point to the same |BIGNUM|. TODO: Test the case where
175   // all of |r|, |a|, and |b| point to the same |BIGNUM|.
176   if (!BN_copy(ret.get(), a.get()) ||
177       !BN_add(ret.get(), ret.get(), b.get()) ||
178       !ExpectBIGNUMsEqual(t, "A + B (r is a)", sum.get(), ret.get()) ||
179       !BN_copy(ret.get(), b.get()) ||
180       !BN_add(ret.get(), a.get(), ret.get()) ||
181       !ExpectBIGNUMsEqual(t, "A + B (r is b)", sum.get(), ret.get()) ||
182       !BN_copy(ret.get(), sum.get()) ||
183       !BN_sub(ret.get(), ret.get(), a.get()) ||
184       !ExpectBIGNUMsEqual(t, "Sum - A (r is a)", b.get(), ret.get()) ||
185       !BN_copy(ret.get(), a.get()) ||
186       !BN_sub(ret.get(), sum.get(), ret.get()) ||
187       !ExpectBIGNUMsEqual(t, "Sum - A (r is b)", b.get(), ret.get()) ||
188       !BN_copy(ret.get(), sum.get()) ||
189       !BN_sub(ret.get(), ret.get(), b.get()) ||
190       !ExpectBIGNUMsEqual(t, "Sum - B (r is a)", a.get(), ret.get()) ||
191       !BN_copy(ret.get(), b.get()) ||
192       !BN_sub(ret.get(), sum.get(), ret.get()) ||
193       !ExpectBIGNUMsEqual(t, "Sum - B (r is b)", a.get(), ret.get())) {
194     return false;
195   }
196 
197   // Test |BN_uadd| and |BN_usub| with the prerequisites they are documented as
198   // having. Note that these functions are frequently used when the
199   // prerequisites don't hold. In those cases, they are supposed to work as if
200   // the prerequisite hold, but we don't test that yet. TODO: test that.
201   if (!BN_is_negative(a.get()) &&
202       !BN_is_negative(b.get()) && BN_cmp(a.get(), b.get()) >= 0) {
203     if (!BN_uadd(ret.get(), a.get(), b.get()) ||
204         !ExpectBIGNUMsEqual(t, "A +u B", sum.get(), ret.get()) ||
205         !BN_usub(ret.get(), sum.get(), a.get()) ||
206         !ExpectBIGNUMsEqual(t, "Sum -u A", b.get(), ret.get()) ||
207         !BN_usub(ret.get(), sum.get(), b.get()) ||
208         !ExpectBIGNUMsEqual(t, "Sum -u B", a.get(), ret.get())) {
209       return false;
210     }
211 
212     // Test that the functions work when |r| and |a| point to the same |BIGNUM|,
213     // or when |r| and |b| point to the same |BIGNUM|. TODO: Test the case where
214     // all of |r|, |a|, and |b| point to the same |BIGNUM|.
215     if (!BN_copy(ret.get(), a.get()) ||
216         !BN_uadd(ret.get(), ret.get(), b.get()) ||
217         !ExpectBIGNUMsEqual(t, "A +u B (r is a)", sum.get(), ret.get()) ||
218         !BN_copy(ret.get(), b.get()) ||
219         !BN_uadd(ret.get(), a.get(), ret.get()) ||
220         !ExpectBIGNUMsEqual(t, "A +u B (r is b)", sum.get(), ret.get()) ||
221         !BN_copy(ret.get(), sum.get()) ||
222         !BN_usub(ret.get(), ret.get(), a.get()) ||
223         !ExpectBIGNUMsEqual(t, "Sum -u A (r is a)", b.get(), ret.get()) ||
224         !BN_copy(ret.get(), a.get()) ||
225         !BN_usub(ret.get(), sum.get(), ret.get()) ||
226         !ExpectBIGNUMsEqual(t, "Sum -u A (r is b)", b.get(), ret.get()) ||
227         !BN_copy(ret.get(), sum.get()) ||
228         !BN_usub(ret.get(), ret.get(), b.get()) ||
229         !ExpectBIGNUMsEqual(t, "Sum -u B (r is a)", a.get(), ret.get()) ||
230         !BN_copy(ret.get(), b.get()) ||
231         !BN_usub(ret.get(), sum.get(), ret.get()) ||
232         !ExpectBIGNUMsEqual(t, "Sum -u B (r is b)", a.get(), ret.get())) {
233       return false;
234     }
235   }
236 
237   // Test with |BN_add_word| and |BN_sub_word| if |b| is small enough.
238   BN_ULONG b_word = BN_get_word(b.get());
239   if (!BN_is_negative(b.get()) && b_word != (BN_ULONG)-1) {
240     if (!BN_copy(ret.get(), a.get()) ||
241         !BN_add_word(ret.get(), b_word) ||
242         !ExpectBIGNUMsEqual(t, "A + B (word)", sum.get(), ret.get()) ||
243         !BN_copy(ret.get(), sum.get()) ||
244         !BN_sub_word(ret.get(), b_word) ||
245         !ExpectBIGNUMsEqual(t, "Sum - B (word)", a.get(), ret.get())) {
246       return false;
247     }
248   }
249 
250   return true;
251 }
252 
TestLShift1(FileTest * t,BN_CTX * ctx)253 static bool TestLShift1(FileTest *t, BN_CTX *ctx) {
254   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
255   bssl::UniquePtr<BIGNUM> lshift1 = GetBIGNUM(t, "LShift1");
256   bssl::UniquePtr<BIGNUM> zero(BN_new());
257   if (!a || !lshift1 || !zero) {
258     return false;
259   }
260 
261   BN_zero(zero.get());
262 
263   bssl::UniquePtr<BIGNUM> ret(BN_new()), two(BN_new()), remainder(BN_new());
264   if (!ret || !two || !remainder ||
265       !BN_set_word(two.get(), 2) ||
266       !BN_add(ret.get(), a.get(), a.get()) ||
267       !ExpectBIGNUMsEqual(t, "A + A", lshift1.get(), ret.get()) ||
268       !BN_mul(ret.get(), a.get(), two.get(), ctx) ||
269       !ExpectBIGNUMsEqual(t, "A * 2", lshift1.get(), ret.get()) ||
270       !BN_div(ret.get(), remainder.get(), lshift1.get(), two.get(), ctx) ||
271       !ExpectBIGNUMsEqual(t, "LShift1 / 2", a.get(), ret.get()) ||
272       !ExpectBIGNUMsEqual(t, "LShift1 % 2", zero.get(), remainder.get()) ||
273       !BN_lshift1(ret.get(), a.get()) ||
274       !ExpectBIGNUMsEqual(t, "A << 1", lshift1.get(), ret.get()) ||
275       !BN_rshift1(ret.get(), lshift1.get()) ||
276       !ExpectBIGNUMsEqual(t, "LShift >> 1", a.get(), ret.get()) ||
277       !BN_rshift1(ret.get(), lshift1.get()) ||
278       !ExpectBIGNUMsEqual(t, "LShift >> 1", a.get(), ret.get())) {
279     return false;
280   }
281 
282   // Set the LSB to 1 and test rshift1 again.
283   if (!BN_set_bit(lshift1.get(), 0) ||
284       !BN_div(ret.get(), nullptr /* rem */, lshift1.get(), two.get(), ctx) ||
285       !ExpectBIGNUMsEqual(t, "(LShift1 | 1) / 2", a.get(), ret.get()) ||
286       !BN_rshift1(ret.get(), lshift1.get()) ||
287       !ExpectBIGNUMsEqual(t, "(LShift | 1) >> 1", a.get(), ret.get())) {
288     return false;
289   }
290 
291   return true;
292 }
293 
TestLShift(FileTest * t,BN_CTX * ctx)294 static bool TestLShift(FileTest *t, BN_CTX *ctx) {
295   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
296   bssl::UniquePtr<BIGNUM> lshift = GetBIGNUM(t, "LShift");
297   int n = 0;
298   if (!a || !lshift || !GetInt(t, &n, "N")) {
299     return false;
300   }
301 
302   bssl::UniquePtr<BIGNUM> ret(BN_new());
303   if (!ret ||
304       !BN_lshift(ret.get(), a.get(), n) ||
305       !ExpectBIGNUMsEqual(t, "A << N", lshift.get(), ret.get()) ||
306       !BN_rshift(ret.get(), lshift.get(), n) ||
307       !ExpectBIGNUMsEqual(t, "A >> N", a.get(), ret.get())) {
308     return false;
309   }
310 
311   return true;
312 }
313 
TestRShift(FileTest * t,BN_CTX * ctx)314 static bool TestRShift(FileTest *t, BN_CTX *ctx) {
315   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
316   bssl::UniquePtr<BIGNUM> rshift = GetBIGNUM(t, "RShift");
317   int n = 0;
318   if (!a || !rshift || !GetInt(t, &n, "N")) {
319     return false;
320   }
321 
322   bssl::UniquePtr<BIGNUM> ret(BN_new());
323   if (!ret ||
324       !BN_rshift(ret.get(), a.get(), n) ||
325       !ExpectBIGNUMsEqual(t, "A >> N", rshift.get(), ret.get())) {
326     return false;
327   }
328 
329   return true;
330 }
331 
TestSquare(FileTest * t,BN_CTX * ctx)332 static bool TestSquare(FileTest *t, BN_CTX *ctx) {
333   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
334   bssl::UniquePtr<BIGNUM> square = GetBIGNUM(t, "Square");
335   bssl::UniquePtr<BIGNUM> zero(BN_new());
336   if (!a || !square || !zero) {
337     return false;
338   }
339 
340   BN_zero(zero.get());
341 
342   bssl::UniquePtr<BIGNUM> ret(BN_new()), remainder(BN_new());
343   if (!ret || !remainder ||
344       !BN_sqr(ret.get(), a.get(), ctx) ||
345       !ExpectBIGNUMsEqual(t, "A^2", square.get(), ret.get()) ||
346       !BN_mul(ret.get(), a.get(), a.get(), ctx) ||
347       !ExpectBIGNUMsEqual(t, "A * A", square.get(), ret.get()) ||
348       !BN_div(ret.get(), remainder.get(), square.get(), a.get(), ctx) ||
349       !ExpectBIGNUMsEqual(t, "Square / A", a.get(), ret.get()) ||
350       !ExpectBIGNUMsEqual(t, "Square % A", zero.get(), remainder.get())) {
351     return false;
352   }
353 
354   BN_set_negative(a.get(), 0);
355   if (!BN_sqrt(ret.get(), square.get(), ctx) ||
356       !ExpectBIGNUMsEqual(t, "sqrt(Square)", a.get(), ret.get())) {
357     return false;
358   }
359 
360   // BN_sqrt should fail on non-squares and negative numbers.
361   if (!BN_is_zero(square.get())) {
362     bssl::UniquePtr<BIGNUM> tmp(BN_new());
363     if (!tmp || !BN_copy(tmp.get(), square.get())) {
364       return false;
365     }
366     BN_set_negative(tmp.get(), 1);
367 
368     if (BN_sqrt(ret.get(), tmp.get(), ctx)) {
369       t->PrintLine("BN_sqrt succeeded on a negative number");
370       return false;
371     }
372     ERR_clear_error();
373 
374     BN_set_negative(tmp.get(), 0);
375     if (!BN_add(tmp.get(), tmp.get(), BN_value_one())) {
376       return false;
377     }
378     if (BN_sqrt(ret.get(), tmp.get(), ctx)) {
379       t->PrintLine("BN_sqrt succeeded on a non-square");
380       return false;
381     }
382     ERR_clear_error();
383   }
384 
385   return true;
386 }
387 
TestProduct(FileTest * t,BN_CTX * ctx)388 static bool TestProduct(FileTest *t, BN_CTX *ctx) {
389   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
390   bssl::UniquePtr<BIGNUM> b = GetBIGNUM(t, "B");
391   bssl::UniquePtr<BIGNUM> product = GetBIGNUM(t, "Product");
392   bssl::UniquePtr<BIGNUM> zero(BN_new());
393   if (!a || !b || !product || !zero) {
394     return false;
395   }
396 
397   BN_zero(zero.get());
398 
399   bssl::UniquePtr<BIGNUM> ret(BN_new()), remainder(BN_new());
400   if (!ret || !remainder ||
401       !BN_mul(ret.get(), a.get(), b.get(), ctx) ||
402       !ExpectBIGNUMsEqual(t, "A * B", product.get(), ret.get()) ||
403       !BN_div(ret.get(), remainder.get(), product.get(), a.get(), ctx) ||
404       !ExpectBIGNUMsEqual(t, "Product / A", b.get(), ret.get()) ||
405       !ExpectBIGNUMsEqual(t, "Product % A", zero.get(), remainder.get()) ||
406       !BN_div(ret.get(), remainder.get(), product.get(), b.get(), ctx) ||
407       !ExpectBIGNUMsEqual(t, "Product / B", a.get(), ret.get()) ||
408       !ExpectBIGNUMsEqual(t, "Product % B", zero.get(), remainder.get())) {
409     return false;
410   }
411 
412   return true;
413 }
414 
TestQuotient(FileTest * t,BN_CTX * ctx)415 static bool TestQuotient(FileTest *t, BN_CTX *ctx) {
416   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
417   bssl::UniquePtr<BIGNUM> b = GetBIGNUM(t, "B");
418   bssl::UniquePtr<BIGNUM> quotient = GetBIGNUM(t, "Quotient");
419   bssl::UniquePtr<BIGNUM> remainder = GetBIGNUM(t, "Remainder");
420   if (!a || !b || !quotient || !remainder) {
421     return false;
422   }
423 
424   bssl::UniquePtr<BIGNUM> ret(BN_new()), ret2(BN_new());
425   if (!ret || !ret2 ||
426       !BN_div(ret.get(), ret2.get(), a.get(), b.get(), ctx) ||
427       !ExpectBIGNUMsEqual(t, "A / B", quotient.get(), ret.get()) ||
428       !ExpectBIGNUMsEqual(t, "A % B", remainder.get(), ret2.get()) ||
429       !BN_mul(ret.get(), quotient.get(), b.get(), ctx) ||
430       !BN_add(ret.get(), ret.get(), remainder.get()) ||
431       !ExpectBIGNUMsEqual(t, "Quotient * B + Remainder", a.get(), ret.get())) {
432     return false;
433   }
434 
435   // Test with |BN_mod_word| and |BN_div_word| if the divisor is small enough.
436   BN_ULONG b_word = BN_get_word(b.get());
437   if (!BN_is_negative(b.get()) && b_word != (BN_ULONG)-1) {
438     BN_ULONG remainder_word = BN_get_word(remainder.get());
439     assert(remainder_word != (BN_ULONG)-1);
440     if (!BN_copy(ret.get(), a.get())) {
441       return false;
442     }
443     BN_ULONG ret_word = BN_div_word(ret.get(), b_word);
444     if (ret_word != remainder_word) {
445       t->PrintLine("Got A %% B (word) = " BN_HEX_FMT1 ", wanted " BN_HEX_FMT1
446                    "\n",
447                    ret_word, remainder_word);
448       return false;
449     }
450     if (!ExpectBIGNUMsEqual(t, "A / B (word)", quotient.get(), ret.get())) {
451       return false;
452     }
453 
454     ret_word = BN_mod_word(a.get(), b_word);
455     if (ret_word != remainder_word) {
456       t->PrintLine("Got A %% B (word) = " BN_HEX_FMT1 ", wanted " BN_HEX_FMT1
457                    "\n",
458                    ret_word, remainder_word);
459       return false;
460     }
461   }
462 
463   // Test BN_nnmod.
464   if (!BN_is_negative(b.get())) {
465     bssl::UniquePtr<BIGNUM> nnmod(BN_new());
466     if (!nnmod ||
467         !BN_copy(nnmod.get(), remainder.get()) ||
468         (BN_is_negative(nnmod.get()) &&
469          !BN_add(nnmod.get(), nnmod.get(), b.get())) ||
470         !BN_nnmod(ret.get(), a.get(), b.get(), ctx) ||
471         !ExpectBIGNUMsEqual(t, "A % B (non-negative)", nnmod.get(),
472                             ret.get())) {
473       return false;
474     }
475   }
476 
477   return true;
478 }
479 
TestModMul(FileTest * t,BN_CTX * ctx)480 static bool TestModMul(FileTest *t, BN_CTX *ctx) {
481   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
482   bssl::UniquePtr<BIGNUM> b = GetBIGNUM(t, "B");
483   bssl::UniquePtr<BIGNUM> m = GetBIGNUM(t, "M");
484   bssl::UniquePtr<BIGNUM> mod_mul = GetBIGNUM(t, "ModMul");
485   if (!a || !b || !m || !mod_mul) {
486     return false;
487   }
488 
489   bssl::UniquePtr<BIGNUM> ret(BN_new());
490   if (!ret ||
491       !BN_mod_mul(ret.get(), a.get(), b.get(), m.get(), ctx) ||
492       !ExpectBIGNUMsEqual(t, "A * B (mod M)", mod_mul.get(), ret.get())) {
493     return false;
494   }
495 
496   if (BN_is_odd(m.get())) {
497     // Reduce |a| and |b| and test the Montgomery version.
498     bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
499     bssl::UniquePtr<BIGNUM> a_tmp(BN_new()), b_tmp(BN_new());
500     if (!mont || !a_tmp || !b_tmp ||
501         !BN_MONT_CTX_set(mont.get(), m.get(), ctx) ||
502         !BN_nnmod(a_tmp.get(), a.get(), m.get(), ctx) ||
503         !BN_nnmod(b_tmp.get(), b.get(), m.get(), ctx) ||
504         !BN_to_montgomery(a_tmp.get(), a_tmp.get(), mont.get(), ctx) ||
505         !BN_to_montgomery(b_tmp.get(), b_tmp.get(), mont.get(), ctx) ||
506         !BN_mod_mul_montgomery(ret.get(), a_tmp.get(), b_tmp.get(), mont.get(),
507                                ctx) ||
508         !BN_from_montgomery(ret.get(), ret.get(), mont.get(), ctx) ||
509         !ExpectBIGNUMsEqual(t, "A * B (mod M) (Montgomery)",
510                             mod_mul.get(), ret.get())) {
511       return false;
512     }
513   }
514 
515   return true;
516 }
517 
TestModSquare(FileTest * t,BN_CTX * ctx)518 static bool TestModSquare(FileTest *t, BN_CTX *ctx) {
519   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
520   bssl::UniquePtr<BIGNUM> m = GetBIGNUM(t, "M");
521   bssl::UniquePtr<BIGNUM> mod_square = GetBIGNUM(t, "ModSquare");
522   if (!a || !m || !mod_square) {
523     return false;
524   }
525 
526   bssl::UniquePtr<BIGNUM> a_copy(BN_new());
527   bssl::UniquePtr<BIGNUM> ret(BN_new());
528   if (!ret || !a_copy ||
529       !BN_mod_mul(ret.get(), a.get(), a.get(), m.get(), ctx) ||
530       !ExpectBIGNUMsEqual(t, "A * A (mod M)", mod_square.get(), ret.get()) ||
531       // Repeat the operation with |a_copy|.
532       !BN_copy(a_copy.get(), a.get()) ||
533       !BN_mod_mul(ret.get(), a.get(), a_copy.get(), m.get(), ctx) ||
534       !ExpectBIGNUMsEqual(t, "A * A_copy (mod M)", mod_square.get(),
535                           ret.get())) {
536     return false;
537   }
538 
539   if (BN_is_odd(m.get())) {
540     // Reduce |a| and test the Montgomery version.
541     bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
542     bssl::UniquePtr<BIGNUM> a_tmp(BN_new());
543     if (!mont || !a_tmp ||
544         !BN_MONT_CTX_set(mont.get(), m.get(), ctx) ||
545         !BN_nnmod(a_tmp.get(), a.get(), m.get(), ctx) ||
546         !BN_to_montgomery(a_tmp.get(), a_tmp.get(), mont.get(), ctx) ||
547         !BN_mod_mul_montgomery(ret.get(), a_tmp.get(), a_tmp.get(), mont.get(),
548                                ctx) ||
549         !BN_from_montgomery(ret.get(), ret.get(), mont.get(), ctx) ||
550         !ExpectBIGNUMsEqual(t, "A * A (mod M) (Montgomery)",
551                             mod_square.get(), ret.get()) ||
552         // Repeat the operation with |a_copy|.
553         !BN_copy(a_copy.get(), a_tmp.get()) ||
554         !BN_mod_mul_montgomery(ret.get(), a_tmp.get(), a_copy.get(), mont.get(),
555                                ctx) ||
556         !BN_from_montgomery(ret.get(), ret.get(), mont.get(), ctx) ||
557         !ExpectBIGNUMsEqual(t, "A * A_copy (mod M) (Montgomery)",
558                             mod_square.get(), ret.get())) {
559       return false;
560     }
561   }
562 
563   return true;
564 }
565 
TestModExp(FileTest * t,BN_CTX * ctx)566 static bool TestModExp(FileTest *t, BN_CTX *ctx) {
567   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
568   bssl::UniquePtr<BIGNUM> e = GetBIGNUM(t, "E");
569   bssl::UniquePtr<BIGNUM> m = GetBIGNUM(t, "M");
570   bssl::UniquePtr<BIGNUM> mod_exp = GetBIGNUM(t, "ModExp");
571   if (!a || !e || !m || !mod_exp) {
572     return false;
573   }
574 
575   bssl::UniquePtr<BIGNUM> ret(BN_new());
576   if (!ret ||
577       !BN_mod_exp(ret.get(), a.get(), e.get(), m.get(), ctx) ||
578       !ExpectBIGNUMsEqual(t, "A ^ E (mod M)", mod_exp.get(), ret.get())) {
579     return false;
580   }
581 
582   if (BN_is_odd(m.get())) {
583     if (!BN_mod_exp_mont(ret.get(), a.get(), e.get(), m.get(), ctx, NULL) ||
584         !ExpectBIGNUMsEqual(t, "A ^ E (mod M) (Montgomery)", mod_exp.get(),
585                             ret.get()) ||
586         !BN_mod_exp_mont_consttime(ret.get(), a.get(), e.get(), m.get(), ctx,
587                                    NULL) ||
588         !ExpectBIGNUMsEqual(t, "A ^ E (mod M) (constant-time)", mod_exp.get(),
589                             ret.get())) {
590       return false;
591     }
592   }
593 
594   return true;
595 }
596 
TestExp(FileTest * t,BN_CTX * ctx)597 static bool TestExp(FileTest *t, BN_CTX *ctx) {
598   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
599   bssl::UniquePtr<BIGNUM> e = GetBIGNUM(t, "E");
600   bssl::UniquePtr<BIGNUM> exp = GetBIGNUM(t, "Exp");
601   if (!a || !e || !exp) {
602     return false;
603   }
604 
605   bssl::UniquePtr<BIGNUM> ret(BN_new());
606   if (!ret ||
607       !BN_exp(ret.get(), a.get(), e.get(), ctx) ||
608       !ExpectBIGNUMsEqual(t, "A ^ E", exp.get(), ret.get())) {
609     return false;
610   }
611 
612   return true;
613 }
614 
TestModSqrt(FileTest * t,BN_CTX * ctx)615 static bool TestModSqrt(FileTest *t, BN_CTX *ctx) {
616   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
617   bssl::UniquePtr<BIGNUM> p = GetBIGNUM(t, "P");
618   bssl::UniquePtr<BIGNUM> mod_sqrt = GetBIGNUM(t, "ModSqrt");
619   bssl::UniquePtr<BIGNUM> mod_sqrt2(BN_new());
620   if (!a || !p || !mod_sqrt || !mod_sqrt2 ||
621       // There are two possible answers.
622       !BN_sub(mod_sqrt2.get(), p.get(), mod_sqrt.get())) {
623     return false;
624   }
625 
626   // -0 is 0, not P.
627   if (BN_is_zero(mod_sqrt.get())) {
628     BN_zero(mod_sqrt2.get());
629   }
630 
631   bssl::UniquePtr<BIGNUM> ret(BN_new());
632   if (!ret ||
633       !BN_mod_sqrt(ret.get(), a.get(), p.get(), ctx)) {
634     return false;
635   }
636 
637   if (BN_cmp(ret.get(), mod_sqrt2.get()) != 0 &&
638       !ExpectBIGNUMsEqual(t, "sqrt(A) (mod P)", mod_sqrt.get(), ret.get())) {
639     return false;
640   }
641 
642   return true;
643 }
644 
TestNotModSquare(FileTest * t,BN_CTX * ctx)645 static bool TestNotModSquare(FileTest *t, BN_CTX *ctx) {
646   bssl::UniquePtr<BIGNUM> not_mod_square = GetBIGNUM(t, "NotModSquare");
647   bssl::UniquePtr<BIGNUM> p = GetBIGNUM(t, "P");
648   bssl::UniquePtr<BIGNUM> ret(BN_new());
649   if (!not_mod_square || !p || !ret) {
650     return false;
651   }
652 
653   if (BN_mod_sqrt(ret.get(), not_mod_square.get(), p.get(), ctx)) {
654     t->PrintLine("BN_mod_sqrt unexpectedly succeeded.");
655     return false;
656   }
657 
658   uint32_t err = ERR_peek_error();
659   if (ERR_GET_LIB(err) == ERR_LIB_BN &&
660       ERR_GET_REASON(err) == BN_R_NOT_A_SQUARE) {
661     ERR_clear_error();
662     return true;
663   }
664 
665   return false;
666 }
667 
TestModInv(FileTest * t,BN_CTX * ctx)668 static bool TestModInv(FileTest *t, BN_CTX *ctx) {
669   bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
670   bssl::UniquePtr<BIGNUM> m = GetBIGNUM(t, "M");
671   bssl::UniquePtr<BIGNUM> mod_inv = GetBIGNUM(t, "ModInv");
672   if (!a || !m || !mod_inv) {
673     return false;
674   }
675 
676   bssl::UniquePtr<BIGNUM> ret(BN_new());
677   if (!ret ||
678       !BN_mod_inverse(ret.get(), a.get(), m.get(), ctx) ||
679       !ExpectBIGNUMsEqual(t, "inv(A) (mod M)", mod_inv.get(), ret.get())) {
680     return false;
681   }
682 
683   return true;
684 }
685 
686 struct Test {
687   const char *name;
688   bool (*func)(FileTest *t, BN_CTX *ctx);
689 };
690 
691 static const Test kTests[] = {
692     {"Sum", TestSum},
693     {"LShift1", TestLShift1},
694     {"LShift", TestLShift},
695     {"RShift", TestRShift},
696     {"Square", TestSquare},
697     {"Product", TestProduct},
698     {"Quotient", TestQuotient},
699     {"ModMul", TestModMul},
700     {"ModSquare", TestModSquare},
701     {"ModExp", TestModExp},
702     {"Exp", TestExp},
703     {"ModSqrt", TestModSqrt},
704     {"NotModSquare", TestNotModSquare},
705     {"ModInv", TestModInv},
706 };
707 
RunTest(FileTest * t,void * arg)708 static bool RunTest(FileTest *t, void *arg) {
709   BN_CTX *ctx = reinterpret_cast<BN_CTX *>(arg);
710   for (const Test &test : kTests) {
711     if (t->GetType() != test.name) {
712       continue;
713     }
714     return test.func(t, ctx);
715   }
716   t->PrintLine("Unknown test type: %s", t->GetType().c_str());
717   return false;
718 }
719 
TestBN2BinPadded(BN_CTX * ctx)720 static bool TestBN2BinPadded(BN_CTX *ctx) {
721   uint8_t zeros[256], out[256], reference[128];
722 
723   OPENSSL_memset(zeros, 0, sizeof(zeros));
724 
725   // Test edge case at 0.
726   bssl::UniquePtr<BIGNUM> n(BN_new());
727   if (!n || !BN_bn2bin_padded(NULL, 0, n.get())) {
728     fprintf(stderr,
729             "BN_bn2bin_padded failed to encode 0 in an empty buffer.\n");
730     return false;
731   }
732   OPENSSL_memset(out, -1, sizeof(out));
733   if (!BN_bn2bin_padded(out, sizeof(out), n.get())) {
734     fprintf(stderr,
735             "BN_bn2bin_padded failed to encode 0 in a non-empty buffer.\n");
736     return false;
737   }
738   if (OPENSSL_memcmp(zeros, out, sizeof(out))) {
739     fprintf(stderr, "BN_bn2bin_padded did not zero buffer.\n");
740     return false;
741   }
742 
743   // Test a random numbers at various byte lengths.
744   for (size_t bytes = 128 - 7; bytes <= 128; bytes++) {
745     if (!BN_rand(n.get(), bytes * 8, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY)) {
746       ERR_print_errors_fp(stderr);
747       return false;
748     }
749     if (BN_num_bytes(n.get()) != bytes ||
750         BN_bn2bin(n.get(), reference) != bytes) {
751       fprintf(stderr, "Bad result from BN_rand; bytes.\n");
752       return false;
753     }
754     // Empty buffer should fail.
755     if (BN_bn2bin_padded(NULL, 0, n.get())) {
756       fprintf(stderr,
757               "BN_bn2bin_padded incorrectly succeeded on empty buffer.\n");
758       return false;
759     }
760     // One byte short should fail.
761     if (BN_bn2bin_padded(out, bytes - 1, n.get())) {
762       fprintf(stderr, "BN_bn2bin_padded incorrectly succeeded on short.\n");
763       return false;
764     }
765     // Exactly right size should encode.
766     if (!BN_bn2bin_padded(out, bytes, n.get()) ||
767         OPENSSL_memcmp(out, reference, bytes) != 0) {
768       fprintf(stderr, "BN_bn2bin_padded gave a bad result.\n");
769       return false;
770     }
771     // Pad up one byte extra.
772     if (!BN_bn2bin_padded(out, bytes + 1, n.get()) ||
773         OPENSSL_memcmp(out + 1, reference, bytes) ||
774         OPENSSL_memcmp(out, zeros, 1)) {
775       fprintf(stderr, "BN_bn2bin_padded gave a bad result.\n");
776       return false;
777     }
778     // Pad up to 256.
779     if (!BN_bn2bin_padded(out, sizeof(out), n.get()) ||
780         OPENSSL_memcmp(out + sizeof(out) - bytes, reference, bytes) ||
781         OPENSSL_memcmp(out, zeros, sizeof(out) - bytes)) {
782       fprintf(stderr, "BN_bn2bin_padded gave a bad result.\n");
783       return false;
784     }
785   }
786 
787   return true;
788 }
789 
TestLittleEndian()790 static bool TestLittleEndian() {
791   bssl::UniquePtr<BIGNUM> x(BN_new());
792   bssl::UniquePtr<BIGNUM> y(BN_new());
793   if (!x || !y) {
794     fprintf(stderr, "BN_new failed to malloc.\n");
795     return false;
796   }
797 
798   // Test edge case at 0. Fill |out| with garbage to ensure |BN_bn2le_padded|
799   // wrote the result.
800   uint8_t out[256], zeros[256];
801   OPENSSL_memset(out, -1, sizeof(out));
802   OPENSSL_memset(zeros, 0, sizeof(zeros));
803   if (!BN_bn2le_padded(out, sizeof(out), x.get()) ||
804       OPENSSL_memcmp(zeros, out, sizeof(out))) {
805     fprintf(stderr, "BN_bn2le_padded failed to encode 0.\n");
806     return false;
807   }
808 
809   if (!BN_le2bn(out, sizeof(out), y.get()) ||
810       BN_cmp(x.get(), y.get()) != 0) {
811     fprintf(stderr, "BN_le2bn failed to decode 0 correctly.\n");
812     return false;
813   }
814 
815   // Test random numbers at various byte lengths.
816   for (size_t bytes = 128 - 7; bytes <= 128; bytes++) {
817     if (!BN_rand(x.get(), bytes * 8, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY)) {
818       ERR_print_errors_fp(stderr);
819       return false;
820     }
821 
822     // Fill |out| with garbage to ensure |BN_bn2le_padded| wrote the result.
823     OPENSSL_memset(out, -1, sizeof(out));
824     if (!BN_bn2le_padded(out, sizeof(out), x.get())) {
825       fprintf(stderr, "BN_bn2le_padded failed to encode random value.\n");
826       return false;
827     }
828 
829     // Compute the expected value by reversing the big-endian output.
830     uint8_t expected[sizeof(out)];
831     if (!BN_bn2bin_padded(expected, sizeof(expected), x.get())) {
832       return false;
833     }
834     for (size_t i = 0; i < sizeof(expected) / 2; i++) {
835       uint8_t tmp = expected[i];
836       expected[i] = expected[sizeof(expected) - 1 - i];
837       expected[sizeof(expected) - 1 - i] = tmp;
838     }
839 
840     if (OPENSSL_memcmp(expected, out, sizeof(out))) {
841       fprintf(stderr, "BN_bn2le_padded failed to encode value correctly.\n");
842       hexdump(stderr, "Expected: ", expected, sizeof(expected));
843       hexdump(stderr, "Got:      ", out, sizeof(out));
844       return false;
845     }
846 
847     // Make sure the decoding produces the same BIGNUM.
848     if (!BN_le2bn(out, bytes, y.get()) ||
849         BN_cmp(x.get(), y.get()) != 0) {
850       bssl::UniquePtr<char> x_hex(BN_bn2hex(x.get())),
851           y_hex(BN_bn2hex(y.get()));
852       if (!x_hex || !y_hex) {
853         return false;
854       }
855       fprintf(stderr, "BN_le2bn failed to decode value correctly.\n");
856       fprintf(stderr, "Expected: %s\n", x_hex.get());
857       hexdump(stderr, "Encoding: ", out, bytes);
858       fprintf(stderr, "Got:      %s\n", y_hex.get());
859       return false;
860     }
861   }
862 
863   return true;
864 }
865 
DecimalToBIGNUM(bssl::UniquePtr<BIGNUM> * out,const char * in)866 static int DecimalToBIGNUM(bssl::UniquePtr<BIGNUM> *out, const char *in) {
867   BIGNUM *raw = NULL;
868   int ret = BN_dec2bn(&raw, in);
869   out->reset(raw);
870   return ret;
871 }
872 
TestDec2BN(BN_CTX * ctx)873 static bool TestDec2BN(BN_CTX *ctx) {
874   bssl::UniquePtr<BIGNUM> bn;
875   int ret = DecimalToBIGNUM(&bn, "0");
876   if (ret != 1 || !BN_is_zero(bn.get()) || BN_is_negative(bn.get())) {
877     fprintf(stderr, "BN_dec2bn gave a bad result.\n");
878     return false;
879   }
880 
881   ret = DecimalToBIGNUM(&bn, "256");
882   if (ret != 3 || !BN_is_word(bn.get(), 256) || BN_is_negative(bn.get())) {
883     fprintf(stderr, "BN_dec2bn gave a bad result.\n");
884     return false;
885   }
886 
887   ret = DecimalToBIGNUM(&bn, "-42");
888   if (ret != 3 || !BN_abs_is_word(bn.get(), 42) || !BN_is_negative(bn.get())) {
889     fprintf(stderr, "BN_dec2bn gave a bad result.\n");
890     return false;
891   }
892 
893   ret = DecimalToBIGNUM(&bn, "-0");
894   if (ret != 2 || !BN_is_zero(bn.get()) || BN_is_negative(bn.get())) {
895     fprintf(stderr, "BN_dec2bn gave a bad result.\n");
896     return false;
897   }
898 
899   ret = DecimalToBIGNUM(&bn, "42trailing garbage is ignored");
900   if (ret != 2 || !BN_abs_is_word(bn.get(), 42) || BN_is_negative(bn.get())) {
901     fprintf(stderr, "BN_dec2bn gave a bad result.\n");
902     return false;
903   }
904 
905   return true;
906 }
907 
TestHex2BN(BN_CTX * ctx)908 static bool TestHex2BN(BN_CTX *ctx) {
909   bssl::UniquePtr<BIGNUM> bn;
910   int ret = HexToBIGNUM(&bn, "0");
911   if (ret != 1 || !BN_is_zero(bn.get()) || BN_is_negative(bn.get())) {
912     fprintf(stderr, "BN_hex2bn gave a bad result.\n");
913     return false;
914   }
915 
916   ret = HexToBIGNUM(&bn, "256");
917   if (ret != 3 || !BN_is_word(bn.get(), 0x256) || BN_is_negative(bn.get())) {
918     fprintf(stderr, "BN_hex2bn gave a bad result.\n");
919     return false;
920   }
921 
922   ret = HexToBIGNUM(&bn, "-42");
923   if (ret != 3 || !BN_abs_is_word(bn.get(), 0x42) || !BN_is_negative(bn.get())) {
924     fprintf(stderr, "BN_hex2bn gave a bad result.\n");
925     return false;
926   }
927 
928   ret = HexToBIGNUM(&bn, "-0");
929   if (ret != 2 || !BN_is_zero(bn.get()) || BN_is_negative(bn.get())) {
930     fprintf(stderr, "BN_hex2bn gave a bad result.\n");
931     return false;
932   }
933 
934   ret = HexToBIGNUM(&bn, "abctrailing garbage is ignored");
935   if (ret != 3 || !BN_is_word(bn.get(), 0xabc) || BN_is_negative(bn.get())) {
936     fprintf(stderr, "BN_hex2bn gave a bad result.\n");
937     return false;
938   }
939 
940   return true;
941 }
942 
ASCIIToBIGNUM(const char * in)943 static bssl::UniquePtr<BIGNUM> ASCIIToBIGNUM(const char *in) {
944   BIGNUM *raw = NULL;
945   if (!BN_asc2bn(&raw, in)) {
946     return nullptr;
947   }
948   return bssl::UniquePtr<BIGNUM>(raw);
949 }
950 
TestASC2BN(BN_CTX * ctx)951 static bool TestASC2BN(BN_CTX *ctx) {
952   bssl::UniquePtr<BIGNUM> bn = ASCIIToBIGNUM("0");
953   if (!bn || !BN_is_zero(bn.get()) || BN_is_negative(bn.get())) {
954     fprintf(stderr, "BN_asc2bn gave a bad result.\n");
955     return false;
956   }
957 
958   bn = ASCIIToBIGNUM("256");
959   if (!bn || !BN_is_word(bn.get(), 256) || BN_is_negative(bn.get())) {
960     fprintf(stderr, "BN_asc2bn gave a bad result.\n");
961     return false;
962   }
963 
964   bn = ASCIIToBIGNUM("-42");
965   if (!bn || !BN_abs_is_word(bn.get(), 42) || !BN_is_negative(bn.get())) {
966     fprintf(stderr, "BN_asc2bn gave a bad result.\n");
967     return false;
968   }
969 
970   bn = ASCIIToBIGNUM("0x1234");
971   if (!bn || !BN_is_word(bn.get(), 0x1234) || BN_is_negative(bn.get())) {
972     fprintf(stderr, "BN_asc2bn gave a bad result.\n");
973     return false;
974   }
975 
976   bn = ASCIIToBIGNUM("0X1234");
977   if (!bn || !BN_is_word(bn.get(), 0x1234) || BN_is_negative(bn.get())) {
978     fprintf(stderr, "BN_asc2bn gave a bad result.\n");
979     return false;
980   }
981 
982   bn = ASCIIToBIGNUM("-0xabcd");
983   if (!bn || !BN_abs_is_word(bn.get(), 0xabcd) || !BN_is_negative(bn.get())) {
984     fprintf(stderr, "BN_asc2bn gave a bad result.\n");
985     return false;
986   }
987 
988   bn = ASCIIToBIGNUM("-0");
989   if (!bn || !BN_is_zero(bn.get()) || BN_is_negative(bn.get())) {
990     fprintf(stderr, "BN_asc2bn gave a bad result.\n");
991     return false;
992   }
993 
994   bn = ASCIIToBIGNUM("123trailing garbage is ignored");
995   if (!bn || !BN_is_word(bn.get(), 123) || BN_is_negative(bn.get())) {
996     fprintf(stderr, "BN_asc2bn gave a bad result.\n");
997     return false;
998   }
999 
1000   return true;
1001 }
1002 
// MPITest pairs a decimal value with its expected |BN_bn2mpi| encoding.
struct MPITest {
  // Value to encode, parsed with |ASCIIToBIGNUM|.
  const char *base10;
  // Expected encoding. It may contain NUL bytes, hence the explicit length.
  const char *mpi;
  size_t mpi_len;
};

// As the vectors below show, an MPI is a 4-byte big-endian length followed by
// the big-endian magnitude. The high bit of the first magnitude byte carries
// the sign ("-1" -> 0x81), so a positive value whose top bit is set needs a
// leading 0x00 ("128" -> 0x00 0x80). Zero has an empty body.
static const MPITest kMPITests[] = {
  { "0", "\x00\x00\x00\x00", 4 },
  { "1", "\x00\x00\x00\x01\x01", 5 },
  { "-1", "\x00\x00\x00\x01\x81", 5 },
  { "128", "\x00\x00\x00\x02\x00\x80", 6 },
  { "256", "\x00\x00\x00\x02\x01\x00", 6 },
  { "-256", "\x00\x00\x00\x02\x81\x00", 6 },
};
1017 
TestMPI()1018 static bool TestMPI() {
1019   uint8_t scratch[8];
1020 
1021   for (size_t i = 0; i < OPENSSL_ARRAY_SIZE(kMPITests); i++) {
1022     const MPITest &test = kMPITests[i];
1023     bssl::UniquePtr<BIGNUM> bn(ASCIIToBIGNUM(test.base10));
1024     if (!bn) {
1025       return false;
1026     }
1027 
1028     const size_t mpi_len = BN_bn2mpi(bn.get(), NULL);
1029     if (mpi_len > sizeof(scratch)) {
1030       fprintf(stderr, "MPI test #%u: MPI size is too large to test.\n",
1031               (unsigned)i);
1032       return false;
1033     }
1034 
1035     const size_t mpi_len2 = BN_bn2mpi(bn.get(), scratch);
1036     if (mpi_len != mpi_len2) {
1037       fprintf(stderr, "MPI test #%u: length changes.\n", (unsigned)i);
1038       return false;
1039     }
1040 
1041     if (mpi_len != test.mpi_len ||
1042         OPENSSL_memcmp(test.mpi, scratch, mpi_len) != 0) {
1043       fprintf(stderr, "MPI test #%u failed:\n", (unsigned)i);
1044       hexdump(stderr, "Expected: ", test.mpi, test.mpi_len);
1045       hexdump(stderr, "Got:      ", scratch, mpi_len);
1046       return false;
1047     }
1048 
1049     bssl::UniquePtr<BIGNUM> bn2(BN_mpi2bn(scratch, mpi_len, NULL));
1050     if (bn2.get() == nullptr) {
1051       fprintf(stderr, "MPI test #%u: failed to parse\n", (unsigned)i);
1052       return false;
1053     }
1054 
1055     if (BN_cmp(bn.get(), bn2.get()) != 0) {
1056       fprintf(stderr, "MPI test #%u: wrong result\n", (unsigned)i);
1057       return false;
1058     }
1059   }
1060 
1061   return true;
1062 }
1063 
TestRand()1064 static bool TestRand() {
1065   bssl::UniquePtr<BIGNUM> bn(BN_new());
1066   if (!bn) {
1067     return false;
1068   }
1069 
1070   // Test BN_rand accounts for degenerate cases with |top| and |bottom|
1071   // parameters.
1072   if (!BN_rand(bn.get(), 0, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY) ||
1073       !BN_is_zero(bn.get())) {
1074     fprintf(stderr, "BN_rand gave a bad result.\n");
1075     return false;
1076   }
1077   if (!BN_rand(bn.get(), 0, BN_RAND_TOP_TWO, BN_RAND_BOTTOM_ODD) ||
1078       !BN_is_zero(bn.get())) {
1079     fprintf(stderr, "BN_rand gave a bad result.\n");
1080     return false;
1081   }
1082 
1083   if (!BN_rand(bn.get(), 1, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY) ||
1084       !BN_is_word(bn.get(), 1)) {
1085     fprintf(stderr, "BN_rand gave a bad result.\n");
1086     return false;
1087   }
1088   if (!BN_rand(bn.get(), 1, BN_RAND_TOP_TWO, BN_RAND_BOTTOM_ANY) ||
1089       !BN_is_word(bn.get(), 1)) {
1090     fprintf(stderr, "BN_rand gave a bad result.\n");
1091     return false;
1092   }
1093   if (!BN_rand(bn.get(), 1, BN_RAND_TOP_ANY, BN_RAND_BOTTOM_ODD) ||
1094       !BN_is_word(bn.get(), 1)) {
1095     fprintf(stderr, "BN_rand gave a bad result.\n");
1096     return false;
1097   }
1098 
1099   if (!BN_rand(bn.get(), 2, BN_RAND_TOP_TWO, BN_RAND_BOTTOM_ANY) ||
1100       !BN_is_word(bn.get(), 3)) {
1101     fprintf(stderr, "BN_rand gave a bad result.\n");
1102     return false;
1103   }
1104 
1105   return true;
1106 }
1107 
// ASN1Test pairs a value (in |BN_asc2bn| syntax) with a DER INTEGER encoding.
struct ASN1Test {
  const char *value_ascii;
  // Encoding bytes; may contain NUL bytes, hence the explicit length.
  const char *der;
  size_t der_len;
};

// kASN1Tests contains valid encodings: tag 0x02 (INTEGER), a length byte, then
// the big-endian contents. A leading 0x00 appears exactly when the top bit of
// the value would otherwise be set (e.g. 128 and 0xdeadbeef).
static const ASN1Test kASN1Tests[] = {
    {"0", "\x02\x01\x00", 3},
    {"1", "\x02\x01\x01", 3},
    {"127", "\x02\x01\x7f", 3},
    {"128", "\x02\x02\x00\x80", 4},
    {"0xdeadbeef", "\x02\x05\x00\xde\xad\xbe\xef", 7},
    {"0x0102030405060708",
     "\x02\x08\x01\x02\x03\x04\x05\x06\x07\x08", 10},
    {"0xffffffffffffffff",
      "\x02\x09\x00\xff\xff\xff\xff\xff\xff\xff\xff", 11},
};

// ASN1InvalidTest holds input that every INTEGER parser must reject.
struct ASN1InvalidTest {
  const char *der;
  size_t der_len;
};

static const ASN1InvalidTest kASN1InvalidTests[] = {
    // Bad tag.
    {"\x03\x01\x00", 3},
    // Empty contents.
    {"\x02\x00", 2},
};

// kASN1BuggyTests contains incorrect encodings and the corresponding, expected
// results of |BN_parse_asn1_unsigned_buggy| given that input.
static const ASN1Test kASN1BuggyTests[] = {
    // Negative numbers (high bit set with no leading 0x00); the buggy parser
    // reads the contents as unsigned.
    {"128", "\x02\x01\x80", 3},
    {"255", "\x02\x01\xff", 3},
    // Unnecessary leading zeros.
    {"1", "\x02\x02\x00\x01", 4},
};
1147 
TestASN1()1148 static bool TestASN1() {
1149   for (const ASN1Test &test : kASN1Tests) {
1150     bssl::UniquePtr<BIGNUM> bn = ASCIIToBIGNUM(test.value_ascii);
1151     if (!bn) {
1152       return false;
1153     }
1154 
1155     // Test that the input is correctly parsed.
1156     bssl::UniquePtr<BIGNUM> bn2(BN_new());
1157     if (!bn2) {
1158       return false;
1159     }
1160     CBS cbs;
1161     CBS_init(&cbs, reinterpret_cast<const uint8_t*>(test.der), test.der_len);
1162     if (!BN_parse_asn1_unsigned(&cbs, bn2.get()) || CBS_len(&cbs) != 0) {
1163       fprintf(stderr, "Parsing ASN.1 INTEGER failed.\n");
1164       return false;
1165     }
1166     if (BN_cmp(bn.get(), bn2.get()) != 0) {
1167       fprintf(stderr, "Bad parse.\n");
1168       return false;
1169     }
1170 
1171     // Test the value serializes correctly.
1172     bssl::ScopedCBB cbb;
1173     uint8_t *der;
1174     size_t der_len;
1175     if (!CBB_init(cbb.get(), 0) ||
1176         !BN_marshal_asn1(cbb.get(), bn.get()) ||
1177         !CBB_finish(cbb.get(), &der, &der_len)) {
1178       return false;
1179     }
1180     bssl::UniquePtr<uint8_t> delete_der(der);
1181     if (der_len != test.der_len ||
1182         OPENSSL_memcmp(der, reinterpret_cast<const uint8_t *>(test.der),
1183                        der_len) != 0) {
1184       fprintf(stderr, "Bad serialization.\n");
1185       return false;
1186     }
1187 
1188     // |BN_parse_asn1_unsigned_buggy| parses all valid input.
1189     CBS_init(&cbs, reinterpret_cast<const uint8_t*>(test.der), test.der_len);
1190     if (!BN_parse_asn1_unsigned_buggy(&cbs, bn2.get()) || CBS_len(&cbs) != 0) {
1191       fprintf(stderr, "Parsing ASN.1 INTEGER failed.\n");
1192       return false;
1193     }
1194     if (BN_cmp(bn.get(), bn2.get()) != 0) {
1195       fprintf(stderr, "Bad parse.\n");
1196       return false;
1197     }
1198   }
1199 
1200   for (const ASN1InvalidTest &test : kASN1InvalidTests) {
1201     bssl::UniquePtr<BIGNUM> bn(BN_new());
1202     if (!bn) {
1203       return false;
1204     }
1205     CBS cbs;
1206     CBS_init(&cbs, reinterpret_cast<const uint8_t*>(test.der), test.der_len);
1207     if (BN_parse_asn1_unsigned(&cbs, bn.get())) {
1208       fprintf(stderr, "Parsed invalid input.\n");
1209       return false;
1210     }
1211     ERR_clear_error();
1212 
1213     // All tests in kASN1InvalidTests are also rejected by
1214     // |BN_parse_asn1_unsigned_buggy|.
1215     CBS_init(&cbs, reinterpret_cast<const uint8_t*>(test.der), test.der_len);
1216     if (BN_parse_asn1_unsigned_buggy(&cbs, bn.get())) {
1217       fprintf(stderr, "Parsed invalid input.\n");
1218       return false;
1219     }
1220     ERR_clear_error();
1221   }
1222 
1223   for (const ASN1Test &test : kASN1BuggyTests) {
1224     // These broken encodings are rejected by |BN_parse_asn1_unsigned|.
1225     bssl::UniquePtr<BIGNUM> bn(BN_new());
1226     if (!bn) {
1227       return false;
1228     }
1229 
1230     CBS cbs;
1231     CBS_init(&cbs, reinterpret_cast<const uint8_t*>(test.der), test.der_len);
1232     if (BN_parse_asn1_unsigned(&cbs, bn.get())) {
1233       fprintf(stderr, "Parsed invalid input.\n");
1234       return false;
1235     }
1236     ERR_clear_error();
1237 
1238     // However |BN_parse_asn1_unsigned_buggy| accepts them.
1239     bssl::UniquePtr<BIGNUM> bn2 = ASCIIToBIGNUM(test.value_ascii);
1240     if (!bn2) {
1241       return false;
1242     }
1243 
1244     CBS_init(&cbs, reinterpret_cast<const uint8_t*>(test.der), test.der_len);
1245     if (!BN_parse_asn1_unsigned_buggy(&cbs, bn.get()) || CBS_len(&cbs) != 0) {
1246       fprintf(stderr, "Parsing (invalid) ASN.1 INTEGER failed.\n");
1247       return false;
1248     }
1249 
1250     if (BN_cmp(bn.get(), bn2.get()) != 0) {
1251       fprintf(stderr, "\"Bad\" parse.\n");
1252       return false;
1253     }
1254   }
1255 
1256   // Serializing negative numbers is not supported.
1257   bssl::UniquePtr<BIGNUM> bn = ASCIIToBIGNUM("-1");
1258   if (!bn) {
1259     return false;
1260   }
1261   bssl::ScopedCBB cbb;
1262   if (!CBB_init(cbb.get(), 0) ||
1263       BN_marshal_asn1(cbb.get(), bn.get())) {
1264     fprintf(stderr, "Serialized negative number.\n");
1265     return false;
1266   }
1267   ERR_clear_error();
1268 
1269   return true;
1270 }
1271 
TestNegativeZero(BN_CTX * ctx)1272 static bool TestNegativeZero(BN_CTX *ctx) {
1273   bssl::UniquePtr<BIGNUM> a(BN_new());
1274   bssl::UniquePtr<BIGNUM> b(BN_new());
1275   bssl::UniquePtr<BIGNUM> c(BN_new());
1276   if (!a || !b || !c) {
1277     return false;
1278   }
1279 
1280   // Test that BN_mul never gives negative zero.
1281   if (!BN_set_word(a.get(), 1)) {
1282     return false;
1283   }
1284   BN_set_negative(a.get(), 1);
1285   BN_zero(b.get());
1286   if (!BN_mul(c.get(), a.get(), b.get(), ctx)) {
1287     return false;
1288   }
1289   if (!BN_is_zero(c.get()) || BN_is_negative(c.get())) {
1290     fprintf(stderr, "Multiplication test failed.\n");
1291     return false;
1292   }
1293 
1294   bssl::UniquePtr<BIGNUM> numerator(BN_new()), denominator(BN_new());
1295   if (!numerator || !denominator) {
1296     return false;
1297   }
1298 
1299   // Test that BN_div never gives negative zero in the quotient.
1300   if (!BN_set_word(numerator.get(), 1) ||
1301       !BN_set_word(denominator.get(), 2)) {
1302     return false;
1303   }
1304   BN_set_negative(numerator.get(), 1);
1305   if (!BN_div(a.get(), b.get(), numerator.get(), denominator.get(), ctx)) {
1306     return false;
1307   }
1308   if (!BN_is_zero(a.get()) || BN_is_negative(a.get())) {
1309     fprintf(stderr, "Incorrect quotient.\n");
1310     return false;
1311   }
1312 
1313   // Test that BN_div never gives negative zero in the remainder.
1314   if (!BN_set_word(denominator.get(), 1)) {
1315     return false;
1316   }
1317   if (!BN_div(a.get(), b.get(), numerator.get(), denominator.get(), ctx)) {
1318     return false;
1319   }
1320   if (!BN_is_zero(b.get()) || BN_is_negative(b.get())) {
1321     fprintf(stderr, "Incorrect remainder.\n");
1322     return false;
1323   }
1324 
1325   // Test that BN_set_negative will not produce a negative zero.
1326   BN_zero(a.get());
1327   BN_set_negative(a.get(), 1);
1328   if (BN_is_negative(a.get())) {
1329     fprintf(stderr, "BN_set_negative produced a negative zero.\n");
1330     return false;
1331   }
1332 
1333   // Test that forcibly creating a negative zero does not break |BN_bn2hex| or
1334   // |BN_bn2dec|.
1335   a->neg = 1;
1336   bssl::UniquePtr<char> dec(BN_bn2dec(a.get()));
1337   bssl::UniquePtr<char> hex(BN_bn2hex(a.get()));
1338   if (!dec || !hex ||
1339       strcmp(dec.get(), "-0") != 0 ||
1340       strcmp(hex.get(), "-0") != 0) {
1341     fprintf(stderr, "BN_bn2dec or BN_bn2hex failed with negative zero.\n");
1342     return false;
1343   }
1344 
1345   // Test that |BN_rshift| and |BN_rshift1| will not produce a negative zero.
1346   if (!BN_set_word(a.get(), 1)) {
1347     return false;
1348   }
1349 
1350   BN_set_negative(a.get(), 1);
1351   if (!BN_rshift(b.get(), a.get(), 1) ||
1352       !BN_rshift1(c.get(), a.get())) {
1353     return false;
1354   }
1355 
1356   if (!BN_is_zero(b.get()) || BN_is_negative(b.get())) {
1357     fprintf(stderr, "BN_rshift(-1, 1) produced the wrong result.\n");
1358     return false;
1359   }
1360 
1361   if (!BN_is_zero(c.get()) || BN_is_negative(c.get())) {
1362     fprintf(stderr, "BN_rshift1(-1) produced the wrong result.\n");
1363     return false;
1364   }
1365 
1366   // Test that |BN_div_word| will not produce a negative zero.
1367   if (BN_div_word(a.get(), 2) == (BN_ULONG)-1) {
1368     return false;
1369   }
1370 
1371   if (!BN_is_zero(a.get()) || BN_is_negative(a.get())) {
1372     fprintf(stderr, "BN_div_word(-1, 2) produced the wrong result.\n");
1373     return false;
1374   }
1375 
1376   return true;
1377 }
1378 
TestBadModulus(BN_CTX * ctx)1379 static bool TestBadModulus(BN_CTX *ctx) {
1380   bssl::UniquePtr<BIGNUM> a(BN_new());
1381   bssl::UniquePtr<BIGNUM> b(BN_new());
1382   bssl::UniquePtr<BIGNUM> zero(BN_new());
1383   bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
1384   if (!a || !b || !zero || !mont) {
1385     return false;
1386   }
1387 
1388   BN_zero(zero.get());
1389 
1390   if (BN_div(a.get(), b.get(), BN_value_one(), zero.get(), ctx)) {
1391     fprintf(stderr, "Division by zero unexpectedly succeeded.\n");
1392     return false;
1393   }
1394   ERR_clear_error();
1395 
1396   if (BN_mod_mul(a.get(), BN_value_one(), BN_value_one(), zero.get(), ctx)) {
1397     fprintf(stderr, "BN_mod_mul with zero modulus unexpectedly succeeded.\n");
1398     return false;
1399   }
1400   ERR_clear_error();
1401 
1402   if (BN_mod_exp(a.get(), BN_value_one(), BN_value_one(), zero.get(), ctx)) {
1403     fprintf(stderr, "BN_mod_exp with zero modulus unexpectedly succeeded.\n");
1404     return 0;
1405   }
1406   ERR_clear_error();
1407 
1408   if (BN_mod_exp_mont(a.get(), BN_value_one(), BN_value_one(), zero.get(), ctx,
1409                       NULL)) {
1410     fprintf(stderr,
1411             "BN_mod_exp_mont with zero modulus unexpectedly succeeded.\n");
1412     return 0;
1413   }
1414   ERR_clear_error();
1415 
1416   if (BN_mod_exp_mont_consttime(a.get(), BN_value_one(), BN_value_one(),
1417                                 zero.get(), ctx, nullptr)) {
1418     fprintf(stderr,
1419             "BN_mod_exp_mont_consttime with zero modulus unexpectedly "
1420             "succeeded.\n");
1421     return 0;
1422   }
1423   ERR_clear_error();
1424 
1425   if (BN_MONT_CTX_set(mont.get(), zero.get(), ctx)) {
1426     fprintf(stderr,
1427             "BN_MONT_CTX_set unexpectedly succeeded for zero modulus.\n");
1428     return false;
1429   }
1430   ERR_clear_error();
1431 
1432   // Some operations also may not be used with an even modulus.
1433 
1434   if (!BN_set_word(b.get(), 16)) {
1435     return false;
1436   }
1437 
1438   if (BN_MONT_CTX_set(mont.get(), b.get(), ctx)) {
1439     fprintf(stderr,
1440             "BN_MONT_CTX_set unexpectedly succeeded for even modulus.\n");
1441     return false;
1442   }
1443   ERR_clear_error();
1444 
1445   if (BN_mod_exp_mont(a.get(), BN_value_one(), BN_value_one(), b.get(), ctx,
1446                       NULL)) {
1447     fprintf(stderr,
1448             "BN_mod_exp_mont with even modulus unexpectedly succeeded.\n");
1449     return 0;
1450   }
1451   ERR_clear_error();
1452 
1453   if (BN_mod_exp_mont_consttime(a.get(), BN_value_one(), BN_value_one(),
1454                                 b.get(), ctx, nullptr)) {
1455     fprintf(stderr,
1456             "BN_mod_exp_mont_consttime with even modulus unexpectedly "
1457             "succeeded.\n");
1458     return 0;
1459   }
1460   ERR_clear_error();
1461 
1462   return true;
1463 }
1464 
1465 // TestExpModZero tests that 1**0 mod 1 == 0.
TestExpModZero()1466 static bool TestExpModZero() {
1467   bssl::UniquePtr<BIGNUM> zero(BN_new()), a(BN_new()), r(BN_new());
1468   if (!zero || !a || !r ||
1469       !BN_rand(a.get(), 1024, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY)) {
1470     return false;
1471   }
1472   BN_zero(zero.get());
1473 
1474   if (!BN_mod_exp(r.get(), a.get(), zero.get(), BN_value_one(), nullptr) ||
1475       !BN_is_zero(r.get()) ||
1476       !BN_mod_exp_mont(r.get(), a.get(), zero.get(), BN_value_one(), nullptr,
1477                        nullptr) ||
1478       !BN_is_zero(r.get()) ||
1479       !BN_mod_exp_mont_consttime(r.get(), a.get(), zero.get(), BN_value_one(),
1480                                  nullptr, nullptr) ||
1481       !BN_is_zero(r.get()) ||
1482       !BN_mod_exp_mont_word(r.get(), 42, zero.get(), BN_value_one(), nullptr,
1483                             nullptr) ||
1484       !BN_is_zero(r.get())) {
1485     return false;
1486   }
1487 
1488   return true;
1489 }
1490 
TestSmallPrime(BN_CTX * ctx)1491 static bool TestSmallPrime(BN_CTX *ctx) {
1492   static const unsigned kBits = 10;
1493 
1494   bssl::UniquePtr<BIGNUM> r(BN_new());
1495   if (!r || !BN_generate_prime_ex(r.get(), static_cast<int>(kBits), 0, NULL,
1496                                   NULL, NULL)) {
1497     return false;
1498   }
1499   if (BN_num_bits(r.get()) != kBits) {
1500     fprintf(stderr, "Expected %u bit prime, got %u bit number\n", kBits,
1501             BN_num_bits(r.get()));
1502     return false;
1503   }
1504 
1505   return true;
1506 }
1507 
TestCmpWord()1508 static bool TestCmpWord() {
1509   static const BN_ULONG kMaxWord = (BN_ULONG)-1;
1510 
1511   bssl::UniquePtr<BIGNUM> r(BN_new());
1512   if (!r ||
1513       !BN_set_word(r.get(), 0)) {
1514     return false;
1515   }
1516 
1517   if (BN_cmp_word(r.get(), 0) != 0 ||
1518       BN_cmp_word(r.get(), 1) >= 0 ||
1519       BN_cmp_word(r.get(), kMaxWord) >= 0) {
1520     fprintf(stderr, "BN_cmp_word compared against 0 incorrectly.\n");
1521     return false;
1522   }
1523 
1524   if (!BN_set_word(r.get(), 100)) {
1525     return false;
1526   }
1527 
1528   if (BN_cmp_word(r.get(), 0) <= 0 ||
1529       BN_cmp_word(r.get(), 99) <= 0 ||
1530       BN_cmp_word(r.get(), 100) != 0 ||
1531       BN_cmp_word(r.get(), 101) >= 0 ||
1532       BN_cmp_word(r.get(), kMaxWord) >= 0) {
1533     fprintf(stderr, "BN_cmp_word compared against 100 incorrectly.\n");
1534     return false;
1535   }
1536 
1537   BN_set_negative(r.get(), 1);
1538 
1539   if (BN_cmp_word(r.get(), 0) >= 0 ||
1540       BN_cmp_word(r.get(), 100) >= 0 ||
1541       BN_cmp_word(r.get(), kMaxWord) >= 0) {
1542     fprintf(stderr, "BN_cmp_word compared against -100 incorrectly.\n");
1543     return false;
1544   }
1545 
1546   if (!BN_set_word(r.get(), kMaxWord)) {
1547     return false;
1548   }
1549 
1550   if (BN_cmp_word(r.get(), 0) <= 0 ||
1551       BN_cmp_word(r.get(), kMaxWord - 1) <= 0 ||
1552       BN_cmp_word(r.get(), kMaxWord) != 0) {
1553     fprintf(stderr, "BN_cmp_word compared against kMaxWord incorrectly.\n");
1554     return false;
1555   }
1556 
1557   if (!BN_add(r.get(), r.get(), BN_value_one())) {
1558     return false;
1559   }
1560 
1561   if (BN_cmp_word(r.get(), 0) <= 0 ||
1562       BN_cmp_word(r.get(), kMaxWord) <= 0) {
1563     fprintf(stderr, "BN_cmp_word compared against kMaxWord + 1 incorrectly.\n");
1564     return false;
1565   }
1566 
1567   BN_set_negative(r.get(), 1);
1568 
1569   if (BN_cmp_word(r.get(), 0) >= 0 ||
1570       BN_cmp_word(r.get(), kMaxWord) >= 0) {
1571     fprintf(stderr,
1572             "BN_cmp_word compared against -kMaxWord - 1 incorrectly.\n");
1573     return false;
1574   }
1575 
1576   return true;
1577 }
1578 
TestBN2Dec()1579 static bool TestBN2Dec() {
1580   static const char *kBN2DecTests[] = {
1581       "0",
1582       "1",
1583       "-1",
1584       "100",
1585       "-100",
1586       "123456789012345678901234567890",
1587       "-123456789012345678901234567890",
1588       "123456789012345678901234567890123456789012345678901234567890",
1589       "-123456789012345678901234567890123456789012345678901234567890",
1590   };
1591 
1592   for (const char *test : kBN2DecTests) {
1593     bssl::UniquePtr<BIGNUM> bn;
1594     int ret = DecimalToBIGNUM(&bn, test);
1595     if (ret == 0) {
1596       return false;
1597     }
1598 
1599     bssl::UniquePtr<char> dec(BN_bn2dec(bn.get()));
1600     if (!dec) {
1601       fprintf(stderr, "BN_bn2dec failed on %s.\n", test);
1602       return false;
1603     }
1604 
1605     if (strcmp(dec.get(), test) != 0) {
1606       fprintf(stderr, "BN_bn2dec gave %s, wanted %s.\n", dec.get(), test);
1607       return false;
1608     }
1609   }
1610 
1611   return true;
1612 }
1613 
TestBNSetGetU64()1614 static bool TestBNSetGetU64() {
1615   static const struct {
1616     const char *hex;
1617     uint64_t value;
1618   } kU64Tests[] = {
1619       {"0", UINT64_C(0x0)},
1620       {"1", UINT64_C(0x1)},
1621       {"ffffffff", UINT64_C(0xffffffff)},
1622       {"100000000", UINT64_C(0x100000000)},
1623       {"ffffffffffffffff", UINT64_C(0xffffffffffffffff)},
1624   };
1625 
1626   for (const auto& test : kU64Tests) {
1627     bssl::UniquePtr<BIGNUM> bn(BN_new()), expected;
1628     if (!bn ||
1629         !BN_set_u64(bn.get(), test.value) ||
1630         !HexToBIGNUM(&expected, test.hex) ||
1631         BN_cmp(bn.get(), expected.get()) != 0) {
1632       fprintf(stderr, "BN_set_u64 test failed for 0x%s.\n", test.hex);
1633       ERR_print_errors_fp(stderr);
1634       return false;
1635     }
1636 
1637     uint64_t tmp;
1638     if (!BN_get_u64(bn.get(), &tmp) || tmp != test.value) {
1639       fprintf(stderr, "BN_get_u64 test failed for 0x%s.\n", test.hex);
1640       return false;
1641     }
1642 
1643     BN_set_negative(bn.get(), 1);
1644     if (!BN_get_u64(bn.get(), &tmp) || tmp != test.value) {
1645       fprintf(stderr, "BN_get_u64 test failed for -0x%s.\n", test.hex);
1646       return false;
1647     }
1648   }
1649 
1650   // Test that BN_get_u64 fails on large numbers.
1651   bssl::UniquePtr<BIGNUM> bn(BN_new());
1652   if (!BN_lshift(bn.get(), BN_value_one(), 64)) {
1653     return false;
1654   }
1655 
1656   uint64_t tmp;
1657   if (BN_get_u64(bn.get(), &tmp)) {
1658     fprintf(stderr, "BN_get_u64 of 2^64 unexpectedly succeeded.\n");
1659     return false;
1660   }
1661 
1662   BN_set_negative(bn.get(), 1);
1663   if (BN_get_u64(bn.get(), &tmp)) {
1664     fprintf(stderr, "BN_get_u64 of -2^64 unexpectedly succeeded.\n");
1665     return false;
1666   }
1667 
1668   return true;
1669 }
1670 
TestBNPow2(BN_CTX * ctx)1671 static bool TestBNPow2(BN_CTX *ctx) {
1672   bssl::UniquePtr<BIGNUM>
1673       power_of_two(BN_new()),
1674       random(BN_new()),
1675       expected(BN_new()),
1676       actual(BN_new());
1677 
1678   if (!power_of_two.get() ||
1679       !random.get() ||
1680       !expected.get() ||
1681       !actual.get()) {
1682     return false;
1683   }
1684 
1685   // Choose an exponent.
1686   for (size_t e = 3; e < 512; e += 11) {
1687     // Choose a bit length for our randoms.
1688     for (int len = 3; len < 512; len += 23) {
1689       // Set power_of_two = 2^e.
1690       if (!BN_lshift(power_of_two.get(), BN_value_one(), (int) e)) {
1691         fprintf(stderr, "Failed to shiftl.\n");
1692         return false;
1693       }
1694 
1695       // Test BN_is_pow2 on power_of_two.
1696       if (!BN_is_pow2(power_of_two.get())) {
1697         fprintf(stderr, "BN_is_pow2 returned false for a power of two.\n");
1698         hexdump(stderr, "Arg: ", power_of_two->d,
1699                 power_of_two->top * sizeof(BN_ULONG));
1700         return false;
1701       }
1702 
1703       // Pick a large random value, ensuring it isn't a power of two.
1704       if (!BN_rand(random.get(), len, BN_RAND_TOP_TWO, BN_RAND_BOTTOM_ANY)) {
1705         fprintf(stderr, "Failed to generate random in TestBNPow2.\n");
1706         return false;
1707       }
1708 
1709       // Test BN_is_pow2 on |r|.
1710       if (BN_is_pow2(random.get())) {
1711         fprintf(stderr, "BN_is_pow2 returned true for a non-power of two.\n");
1712         hexdump(stderr, "Arg: ", random->d, random->top * sizeof(BN_ULONG));
1713         return false;
1714       }
1715 
1716       // Test BN_mod_pow2 on |r|.
1717       if (!BN_mod(expected.get(), random.get(), power_of_two.get(), ctx) ||
1718           !BN_mod_pow2(actual.get(), random.get(), e) ||
1719           BN_cmp(actual.get(), expected.get())) {
1720         fprintf(stderr, "BN_mod_pow2 returned the wrong value:\n");
1721         hexdump(stderr, "Expected: ", expected->d,
1722                 expected->top * sizeof(BN_ULONG));
1723         hexdump(stderr, "Got:      ", actual->d,
1724                 actual->top * sizeof(BN_ULONG));
1725         return false;
1726       }
1727 
1728       // Test BN_nnmod_pow2 on |r|.
1729       if (!BN_nnmod(expected.get(), random.get(), power_of_two.get(), ctx) ||
1730           !BN_nnmod_pow2(actual.get(), random.get(), e) ||
1731           BN_cmp(actual.get(), expected.get())) {
1732         fprintf(stderr, "BN_nnmod_pow2 failed on positive input:\n");
1733         hexdump(stderr, "Expected: ", expected->d,
1734                 expected->top * sizeof(BN_ULONG));
1735         hexdump(stderr, "Got:      ", actual->d,
1736                 actual->top * sizeof(BN_ULONG));
1737         return false;
1738       }
1739 
1740       // Test BN_nnmod_pow2 on -|r|.
1741       BN_set_negative(random.get(), 1);
1742       if (!BN_nnmod(expected.get(), random.get(), power_of_two.get(), ctx) ||
1743           !BN_nnmod_pow2(actual.get(), random.get(), e) ||
1744           BN_cmp(actual.get(), expected.get())) {
1745         fprintf(stderr, "BN_nnmod_pow2 failed on negative input:\n");
1746         hexdump(stderr, "Expected: ", expected->d,
1747                 expected->top * sizeof(BN_ULONG));
1748         hexdump(stderr, "Got:      ", actual->d,
1749                 actual->top * sizeof(BN_ULONG));
1750         return false;
1751       }
1752     }
1753   }
1754 
1755   return true;
1756 }
1757 
main(int argc,char * argv[])1758 int main(int argc, char *argv[]) {
1759   CRYPTO_library_init();
1760 
1761   if (argc != 2) {
1762     fprintf(stderr, "%s TEST_FILE\n", argv[0]);
1763     return 1;
1764   }
1765 
1766   bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new());
1767   if (!ctx) {
1768     return 1;
1769   }
1770 
1771   if (!TestBN2BinPadded(ctx.get()) ||
1772       !TestDec2BN(ctx.get()) ||
1773       !TestHex2BN(ctx.get()) ||
1774       !TestASC2BN(ctx.get()) ||
1775       !TestLittleEndian() ||
1776       !TestMPI() ||
1777       !TestRand() ||
1778       !TestASN1() ||
1779       !TestNegativeZero(ctx.get()) ||
1780       !TestBadModulus(ctx.get()) ||
1781       !TestExpModZero() ||
1782       !TestSmallPrime(ctx.get()) ||
1783       !TestCmpWord() ||
1784       !TestBN2Dec() ||
1785       !TestBNSetGetU64() ||
1786       !TestBNPow2(ctx.get())) {
1787     return 1;
1788   }
1789 
1790   return FileTestMain(RunTest, ctx.get(), argv[1]);
1791 }
1792