1 // Copyright 2015-2016 Brian Smith.
2 //
3 // Permission to use, copy, modify, and/or distribute this software for any
4 // purpose with or without fee is hereby granted, provided that the above
5 // copyright notice and this permission notice appear in all copies.
6 //
7 // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
8 // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
10 // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 
15 //! Multi-precision integers.
16 //!
17 //! # Modular Arithmetic.
18 //!
19 //! Modular arithmetic is done in finite commutative rings ℤ/mℤ for some
20 //! modulus *m*. We work in finite commutative rings instead of finite fields
21 //! because the RSA public modulus *n* is not prime, which means ℤ/nℤ contains
22 //! nonzero elements that have no multiplicative inverse, so ℤ/nℤ is not a
23 //! finite field.
24 //!
25 //! In some calculations we need to deal with multiple rings at once. For
26 //! example, RSA private key operations operate in the rings ℤ/nℤ, ℤ/pℤ, and
27 //! ℤ/qℤ. Types and functions dealing with such rings are all parameterized
28 //! over a type `M` to ensure that we don't wrongly mix up the math, e.g. by
29 //! multiplying an element of ℤ/pℤ by an element of ℤ/qℤ modulo q. This follows
30 //! the "unit" pattern described in [Static checking of units in Servo].
31 //!
//! `Elem` also uses the static unit checking pattern to statically track the
//! Montgomery factors that need to be canceled out in each value using its
//! `E` parameter.
35 //!
36 //! [Static checking of units in Servo]:
37 //!     https://blog.mozilla.org/research/2014/06/23/static-checking-of-units-in-servo/
38 
39 use crate::{
40     arithmetic::montgomery::*,
41     bits, bssl, c, error,
42     limb::{self, Limb, LimbMask, LIMB_BITS, LIMB_BYTES},
43 };
44 use alloc::{borrow::ToOwned as _, boxed::Box, vec, vec::Vec};
45 use core::{
46     marker::PhantomData,
47     ops::{Deref, DerefMut},
48 };
49 
/// Marker trait for modulus types `M`.
///
/// In the visible part of this file it gates `PrivateExponent::<M>::for_flt`,
/// which is only implemented for `M: Prime`.
///
/// NOTE(review): presumably implementing this asserts that the modulus the
/// type stands for is prime, which would explain why it is `unsafe` to
/// implement — confirm.
pub unsafe trait Prime {}
51 
/// The number of limbs used to store values belonging to the modulus type
/// `M`. `BoxedLimbs::zero` allocates exactly `num_limbs` limbs from it.
struct Width<M> {
    /// How many limbs a value of this width occupies.
    num_limbs: usize,

    /// The modulus *m* that the width originated from.
    m: PhantomData<M>,
}
58 
/// Heap-allocated limbs of a value tagged with the modulus type `M`.
///
/// All `BoxedLimbs<M>` are stored in the same number of limbs.
struct BoxedLimbs<M> {
    /// The limbs of the value; index 0 is the least significant limb (see
    /// `Modulus::one`, which sets `limbs[0] = 1`).
    limbs: Box<[Limb]>,

    /// The modulus *m* that determines the size of `limbs`.
    m: PhantomData<M>,
}
66 
67 impl<M> Deref for BoxedLimbs<M> {
68     type Target = [Limb];
69     #[inline]
deref(&self) -> &Self::Target70     fn deref(&self) -> &Self::Target {
71         &self.limbs
72     }
73 }
74 
75 impl<M> DerefMut for BoxedLimbs<M> {
76     #[inline]
deref_mut(&mut self) -> &mut Self::Target77     fn deref_mut(&mut self) -> &mut Self::Target {
78         &mut self.limbs
79     }
80 }
81 
82 // TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
83 // is resolved or restrict `M: Clone`.
84 impl<M> Clone for BoxedLimbs<M> {
clone(&self) -> Self85     fn clone(&self) -> Self {
86         Self {
87             limbs: self.limbs.clone(),
88             m: self.m,
89         }
90     }
91 }
92 
93 impl<M> BoxedLimbs<M> {
positive_minimal_width_from_be_bytes( input: untrusted::Input, ) -> Result<Self, error::KeyRejected>94     fn positive_minimal_width_from_be_bytes(
95         input: untrusted::Input,
96     ) -> Result<Self, error::KeyRejected> {
97         // Reject leading zeros. Also reject the value zero ([0]) because zero
98         // isn't positive.
99         if untrusted::Reader::new(input).peek(0) {
100             return Err(error::KeyRejected::invalid_encoding());
101         }
102         let num_limbs = (input.len() + LIMB_BYTES - 1) / LIMB_BYTES;
103         let mut r = Self::zero(Width {
104             num_limbs,
105             m: PhantomData,
106         });
107         limb::parse_big_endian_and_pad_consttime(input, &mut r)
108             .map_err(|error::Unspecified| error::KeyRejected::unexpected_error())?;
109         Ok(r)
110     }
111 
minimal_width_from_unpadded(limbs: &[Limb]) -> Self112     fn minimal_width_from_unpadded(limbs: &[Limb]) -> Self {
113         debug_assert_ne!(limbs.last(), Some(&0));
114         Self {
115             limbs: limbs.to_owned().into_boxed_slice(),
116             m: PhantomData,
117         }
118     }
119 
from_be_bytes_padded_less_than( input: untrusted::Input, m: &Modulus<M>, ) -> Result<Self, error::Unspecified>120     fn from_be_bytes_padded_less_than(
121         input: untrusted::Input,
122         m: &Modulus<M>,
123     ) -> Result<Self, error::Unspecified> {
124         let mut r = Self::zero(m.width());
125         limb::parse_big_endian_and_pad_consttime(input, &mut r)?;
126         if limb::limbs_less_than_limbs_consttime(&r, &m.limbs) != LimbMask::True {
127             return Err(error::Unspecified);
128         }
129         Ok(r)
130     }
131 
132     #[inline]
is_zero(&self) -> bool133     fn is_zero(&self) -> bool {
134         limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True
135     }
136 
zero(width: Width<M>) -> Self137     fn zero(width: Width<M>) -> Self {
138         Self {
139             limbs: vec![0; width.num_limbs].into_boxed_slice(),
140             m: PhantomData,
141         }
142     }
143 
width(&self) -> Width<M>144     fn width(&self) -> Width<M> {
145         Width {
146             num_limbs: self.limbs.len(),
147             m: PhantomData,
148         }
149     }
150 }
151 
/// A modulus *s* that is smaller than another modulus *l* so every element of
/// ℤ/sℤ is also an element of ℤ/lℤ.
///
/// `elem_widen` and `Modulus::to_elem` are only available when this bound
/// holds.
///
/// NOTE(review): presumably `unsafe` because the implementor vouches that the
/// relation above really holds for the two moduli — confirm.
pub unsafe trait SmallerModulus<L> {}
155 
/// A modulus *s* where s < l < 2*s for the given larger modulus *l*. This is
/// the precondition for reduction by conditional subtraction,
/// `elem_reduced_once()`.
///
/// NOTE(review): presumably `unsafe` because the implementor vouches for the
/// inequality above — confirm.
pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}
160 
/// A modulus *s* where √l <= s < l for the given larger modulus *l*. This is
/// the precondition for the more general Montgomery reduction from ℤ/lℤ to
/// ℤ/sℤ, `elem_reduced()`.
///
/// NOTE(review): presumably `unsafe` because the implementor vouches for the
/// inequality above — confirm.
pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}
165 
/// Marker trait for modulus types `M`.
///
/// In the visible part of this file it gates the `Debug` implementation of
/// `Modulus<M>`.
///
/// NOTE(review): presumably marks moduli whose value is not secret, so that
/// exposing them through `Debug` is acceptable — confirm.
pub unsafe trait PublicModulus {}
167 
/// The minimum number of limbs in a modulus; `Modulus::from_boxed_limbs`
/// rejects anything shorter.
///
/// The x86 implementation of `GFp_bn_mul_mont`, at least, requires at least 4
/// limbs. For a long time we have required 4 limbs for all targets, though
/// this may be unnecessary. TODO: Replace this with
/// `n.len() < 256 / LIMB_BITS` so that 32-bit and 64-bit platforms behave the
/// same.
pub const MODULUS_MIN_LIMBS: usize = 4;
174 
/// The maximum number of limbs in a modulus, i.e. 8192 bits' worth.
///
/// `Modulus::from_boxed_limbs` rejects longer moduli as too large. The
/// constant also sizes the stack buffers used by `Elem::decode_once` and
/// `elem_reduced`.
pub const MODULUS_MAX_LIMBS: usize = 8192 / LIMB_BITS;
176 
/// The modulus *m* for a ring ℤ/mℤ, along with the precomputed values needed
/// for efficient Montgomery multiplication modulo *m*. The value must be odd
/// and larger than 2 (both are checked when the modulus is constructed). The
/// larger-than-1 part of that requirement is imposed, at least, by the
/// modular inversion code.
pub struct Modulus<M> {
    limbs: BoxedLimbs<M>, // Also `value >= 3`.

    // n0 * N == -1 (mod r).
    //
    // r == 2**(N0_LIMBS_USED * LIMB_BITS) and LG_LITTLE_R == lg(r). This
    // ensures that we can do integer division by |r| by simply ignoring
    // `N0_LIMBS_USED` limbs. Similarly, we can calculate values modulo `r` by
    // just looking at the lowest `N0_LIMBS_USED` limbs. This is what makes
    // Montgomery multiplication efficient.
    //
    // As shown in Algorithm 1 of "Fast Prime Field Elliptic Curve Cryptography
    // with 256 Bit Primes" by Shay Gueron and Vlad Krasnov, in the loop of a
    // multi-limb Montgomery multiplication of a * b (mod n), given the
    // unreduced product t == a * b, we repeatedly calculate:
    //
    //    t1 := t % r         |t1| is |t|'s lowest limb (see previous paragraph).
    //    t2 := t1*n0*n
    //    t3 := t + t2
    //    t := t3 / r         copy all limbs of |t3| except the lowest to |t|.
    //
    // In the last step, it would only make sense to ignore the lowest limb of
    // |t3| if it were zero. The middle steps ensure that this is the case:
    //
    //                            t3 ==  0 (mod r)
    //                        t + t2 ==  0 (mod r)
    //                   t + t1*n0*n ==  0 (mod r)
    //                       t1*n0*n == -t (mod r)
    //                        t*n0*n == -t (mod r)    (since t1 == t (mod r))
    //                          n0*n == -1 (mod r)
    //                            n0 == -1/n (mod r)
    //
    // Thus, in each iteration of the loop, we multiply by the constant factor
    // n0, the negative inverse of n (mod r).
    //
    // TODO(perf): Not all 32-bit platforms actually make use of n0[1]. For the
    // ones that don't, we could use a shorter `R` value and use faster `Limb`
    // calculations instead of double-precision `u64` calculations.
    n0: N0,

    // The value produced by `One::newRR` for this modulus when it is
    // constructed; exposed through `Modulus::oneRR()`.
    oneRR: One<M, RR>,
}
223 
224 impl<M: PublicModulus> core::fmt::Debug for Modulus<M> {
fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error>225     fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> {
226         fmt.debug_struct("Modulus")
227             // TODO: Print modulus value.
228             .finish()
229     }
230 }
231 
232 impl<M> Modulus<M> {
from_be_bytes_with_bit_length( input: untrusted::Input, ) -> Result<(Self, bits::BitLength), error::KeyRejected>233     pub fn from_be_bytes_with_bit_length(
234         input: untrusted::Input,
235     ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
236         let limbs = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
237         Self::from_boxed_limbs(limbs)
238     }
239 
from_nonnegative_with_bit_length( n: Nonnegative, ) -> Result<(Self, bits::BitLength), error::KeyRejected>240     pub fn from_nonnegative_with_bit_length(
241         n: Nonnegative,
242     ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
243         let limbs = BoxedLimbs {
244             limbs: n.limbs.into_boxed_slice(),
245             m: PhantomData,
246         };
247         Self::from_boxed_limbs(limbs)
248     }
249 
from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected>250     fn from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected> {
251         if n.len() > MODULUS_MAX_LIMBS {
252             return Err(error::KeyRejected::too_large());
253         }
254         if n.len() < MODULUS_MIN_LIMBS {
255             return Err(error::KeyRejected::unexpected_error());
256         }
257         if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
258             return Err(error::KeyRejected::invalid_component());
259         }
260         if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
261             return Err(error::KeyRejected::unexpected_error());
262         }
263 
264         // n_mod_r = n % r. As explained in the documentation for `n0`, this is
265         // done by taking the lowest `N0_LIMBS_USED` limbs of `n`.
266         let n0 = {
267             extern "C" {
268                 fn GFp_bn_neg_inv_mod_r_u64(n: u64) -> u64;
269             }
270 
271             // XXX: u64::from isn't guaranteed to be constant time.
272             let mut n_mod_r: u64 = u64::from(n[0]);
273 
274             if N0_LIMBS_USED == 2 {
275                 // XXX: If we use `<< LIMB_BITS` here then 64-bit builds
276                 // fail to compile because of `deny(exceeding_bitshifts)`.
277                 debug_assert_eq!(LIMB_BITS, 32);
278                 n_mod_r |= u64::from(n[1]) << 32;
279             }
280             N0::from(unsafe { GFp_bn_neg_inv_mod_r_u64(n_mod_r) })
281         };
282 
283         let bits = limb::limbs_minimal_bits(&n.limbs);
284         let oneRR = {
285             let partial = PartialModulus {
286                 limbs: &n.limbs,
287                 n0: n0.clone(),
288                 m: PhantomData,
289             };
290 
291             One::newRR(&partial, bits)
292         };
293 
294         Ok((
295             Self {
296                 limbs: n,
297                 n0,
298                 oneRR,
299             },
300             bits,
301         ))
302     }
303 
304     #[inline]
width(&self) -> Width<M>305     fn width(&self) -> Width<M> {
306         self.limbs.width()
307     }
308 
zero<E>(&self) -> Elem<M, E>309     fn zero<E>(&self) -> Elem<M, E> {
310         Elem {
311             limbs: BoxedLimbs::zero(self.width()),
312             encoding: PhantomData,
313         }
314     }
315 
316     // TODO: Get rid of this
one(&self) -> Elem<M, Unencoded>317     fn one(&self) -> Elem<M, Unencoded> {
318         let mut r = self.zero();
319         r.limbs[0] = 1;
320         r
321     }
322 
oneRR(&self) -> &One<M, RR>323     pub fn oneRR(&self) -> &One<M, RR> {
324         &self.oneRR
325     }
326 
to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded> where M: SmallerModulus<L>,327     pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded>
328     where
329         M: SmallerModulus<L>,
330     {
331         // TODO: Encode this assertion into the `where` above.
332         assert_eq!(self.width().num_limbs, l.width().num_limbs);
333         let limbs = self.limbs.clone();
334         Elem {
335             limbs: BoxedLimbs {
336                 limbs: limbs.limbs,
337                 m: PhantomData,
338             },
339             encoding: PhantomData,
340         }
341     }
342 
as_partial(&self) -> PartialModulus<M>343     fn as_partial(&self) -> PartialModulus<M> {
344         PartialModulus {
345             limbs: &self.limbs,
346             n0: self.n0.clone(),
347             m: PhantomData,
348         }
349     }
350 }
351 
/// A borrowed view of a modulus that carries only its limbs and `n0`, i.e.
/// everything in `Modulus` except `oneRR`.
///
/// `Modulus::from_boxed_limbs` builds one to compute `oneRR` before the full
/// `Modulus` exists; `Modulus::as_partial` builds one from a finished modulus.
struct PartialModulus<'a, M> {
    /// The limbs of the modulus.
    limbs: &'a [Limb],
    /// See the documentation of `Modulus::n0`.
    n0: N0,
    /// Ties the view to the modulus type `M`.
    m: PhantomData<M>,
}
357 
358 impl<M> PartialModulus<'_, M> {
359     // TODO: XXX Avoid duplication with `Modulus`.
zero(&self) -> Elem<M, R>360     fn zero(&self) -> Elem<M, R> {
361         let width = Width {
362             num_limbs: self.limbs.len(),
363             m: PhantomData,
364         };
365         Elem {
366             limbs: BoxedLimbs::zero(width),
367             encoding: PhantomData,
368         }
369     }
370 }
371 
/// Elements of ℤ/mℤ for some modulus *m*.
///
/// `M` identifies the modulus and `E` the encoding, i.e. which Montgomery
/// factors are still present in the stored value.
//
// Defaulting `E` to `Unencoded` is a convenience for callers from outside this
// submodule. However, for maximum clarity, we always explicitly use
// `Unencoded` within the `bigint` submodule.
pub struct Elem<M, E = Unencoded> {
    /// The stored (possibly Montgomery-encoded) value.
    limbs: BoxedLimbs<M>,

    /// The number of Montgomery factors that need to be canceled out from
    /// the value stored in `limbs` to get the actual value.
    encoding: PhantomData<E>,
}
384 
385 // TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
386 // is resolved or restrict `M: Clone` and `E: Clone`.
387 impl<M, E> Clone for Elem<M, E> {
clone(&self) -> Self388     fn clone(&self) -> Self {
389         Self {
390             limbs: self.limbs.clone(),
391             encoding: self.encoding,
392         }
393     }
394 }
395 
396 impl<M, E> Elem<M, E> {
397     #[inline]
is_zero(&self) -> bool398     pub fn is_zero(&self) -> bool {
399         self.limbs.is_zero()
400     }
401 }
402 
403 impl<M, E: ReductionEncoding> Elem<M, E> {
decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output>404     fn decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output> {
405         // A multiplication isn't required since we're multiplying by the
406         // unencoded value one (1); only a Montgomery reduction is needed.
407         // However the only non-multiplication Montgomery reduction function we
408         // have requires the input to be large, so we avoid using it here.
409         let mut limbs = self.limbs;
410         let num_limbs = m.width().num_limbs;
411         let mut one = [0; MODULUS_MAX_LIMBS];
412         one[0] = 1;
413         let one = &one[..num_limbs]; // assert!(num_limbs <= MODULUS_MAX_LIMBS);
414         limbs_mont_mul(&mut limbs, &one, &m.limbs, &m.n0);
415         Elem {
416             limbs,
417             encoding: PhantomData,
418         }
419     }
420 }
421 
422 impl<M> Elem<M, R> {
423     #[inline]
into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded>424     pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
425         self.decode_once(m)
426     }
427 }
428 
429 impl<M> Elem<M, Unencoded> {
from_be_bytes_padded( input: untrusted::Input, m: &Modulus<M>, ) -> Result<Self, error::Unspecified>430     pub fn from_be_bytes_padded(
431         input: untrusted::Input,
432         m: &Modulus<M>,
433     ) -> Result<Self, error::Unspecified> {
434         Ok(Elem {
435             limbs: BoxedLimbs::from_be_bytes_padded_less_than(input, m)?,
436             encoding: PhantomData,
437         })
438     }
439 
440     #[inline]
fill_be_bytes(&self, out: &mut [u8])441     pub fn fill_be_bytes(&self, out: &mut [u8]) {
442         // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
443         limb::big_endian_from_limbs(&self.limbs, out)
444     }
445 
into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected>446     pub fn into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected> {
447         let (m, _bits) =
448             Modulus::from_boxed_limbs(BoxedLimbs::minimal_width_from_unpadded(&self.limbs))?;
449         Ok(m)
450     }
451 
is_one(&self) -> bool452     fn is_one(&self) -> bool {
453         limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
454     }
455 }
456 
elem_mul<M, AF, BF>( a: &Elem<M, AF>, b: Elem<M, BF>, m: &Modulus<M>, ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output> where (AF, BF): ProductEncoding,457 pub fn elem_mul<M, AF, BF>(
458     a: &Elem<M, AF>,
459     b: Elem<M, BF>,
460     m: &Modulus<M>,
461 ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
462 where
463     (AF, BF): ProductEncoding,
464 {
465     elem_mul_(a, b, &m.as_partial())
466 }
467 
elem_mul_<M, AF, BF>( a: &Elem<M, AF>, mut b: Elem<M, BF>, m: &PartialModulus<M>, ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output> where (AF, BF): ProductEncoding,468 fn elem_mul_<M, AF, BF>(
469     a: &Elem<M, AF>,
470     mut b: Elem<M, BF>,
471     m: &PartialModulus<M>,
472 ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
473 where
474     (AF, BF): ProductEncoding,
475 {
476     limbs_mont_mul(&mut b.limbs, &a.limbs, &m.limbs, &m.n0);
477     Elem {
478         limbs: b.limbs,
479         encoding: PhantomData,
480     }
481 }
482 
elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>)483 fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
484     extern "C" {
485         fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
486     }
487     unsafe {
488         LIMBS_shl_mod(
489             a.limbs.as_mut_ptr(),
490             a.limbs.as_ptr(),
491             m.limbs.as_ptr(),
492             m.limbs.len(),
493         );
494     }
495 }
496 
elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>( a: &Elem<Larger, Unencoded>, m: &Modulus<Smaller>, ) -> Elem<Smaller, Unencoded>497 pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
498     a: &Elem<Larger, Unencoded>,
499     m: &Modulus<Smaller>,
500 ) -> Elem<Smaller, Unencoded> {
501     let mut r = a.limbs.clone();
502     assert!(r.len() <= m.limbs.len());
503     limb::limbs_reduce_once_constant_time(&mut r, &m.limbs);
504     Elem {
505         limbs: BoxedLimbs {
506             limbs: r.limbs,
507             m: PhantomData,
508         },
509         encoding: PhantomData,
510     }
511 }
512 
513 #[inline]
elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>( a: &Elem<Larger, Unencoded>, m: &Modulus<Smaller>, ) -> Elem<Smaller, RInverse>514 pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
515     a: &Elem<Larger, Unencoded>,
516     m: &Modulus<Smaller>,
517 ) -> Elem<Smaller, RInverse> {
518     let mut tmp = [0; MODULUS_MAX_LIMBS];
519     let tmp = &mut tmp[..a.limbs.len()];
520     tmp.copy_from_slice(&a.limbs);
521 
522     let mut r = m.zero();
523     limbs_from_mont_in_place(&mut r.limbs, tmp, &m.limbs, &m.n0);
524     r
525 }
526 
elem_squared<M, E>( mut a: Elem<M, E>, m: &PartialModulus<M>, ) -> Elem<M, <(E, E) as ProductEncoding>::Output> where (E, E): ProductEncoding,527 fn elem_squared<M, E>(
528     mut a: Elem<M, E>,
529     m: &PartialModulus<M>,
530 ) -> Elem<M, <(E, E) as ProductEncoding>::Output>
531 where
532     (E, E): ProductEncoding,
533 {
534     limbs_mont_square(&mut a.limbs, &m.limbs, &m.n0);
535     Elem {
536         limbs: a.limbs,
537         encoding: PhantomData,
538     }
539 }
540 
elem_widen<Larger, Smaller: SmallerModulus<Larger>>( a: Elem<Smaller, Unencoded>, m: &Modulus<Larger>, ) -> Elem<Larger, Unencoded>541 pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
542     a: Elem<Smaller, Unencoded>,
543     m: &Modulus<Larger>,
544 ) -> Elem<Larger, Unencoded> {
545     let mut r = m.zero();
546     r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs);
547     r
548 }
549 
550 // TODO: Document why this works for all Montgomery factors.
elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E>551 pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
552     extern "C" {
553         // `r` and `a` may alias.
554         fn LIMBS_add_mod(
555             r: *mut Limb,
556             a: *const Limb,
557             b: *const Limb,
558             m: *const Limb,
559             num_limbs: c::size_t,
560         );
561     }
562     unsafe {
563         LIMBS_add_mod(
564             a.limbs.as_mut_ptr(),
565             a.limbs.as_ptr(),
566             b.limbs.as_ptr(),
567             m.limbs.as_ptr(),
568             m.limbs.len(),
569         )
570     }
571     a
572 }
573 
574 // TODO: Document why this works for all Montgomery factors.
elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E>575 pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
576     extern "C" {
577         // `r` and `a` may alias.
578         fn LIMBS_sub_mod(
579             r: *mut Limb,
580             a: *const Limb,
581             b: *const Limb,
582             m: *const Limb,
583             num_limbs: c::size_t,
584         );
585     }
586     unsafe {
587         LIMBS_sub_mod(
588             a.limbs.as_mut_ptr(),
589             a.limbs.as_ptr(),
590             b.limbs.as_ptr(),
591             m.limbs.as_ptr(),
592             m.limbs.len(),
593         );
594     }
595     a
596 }
597 
/// The value 1, Montgomery-encoded some number of times.
///
/// `E` records the encoding; for example `One<M, RR>` is what `One::newRR`
/// computes and `Modulus::oneRR()` hands out. The wrapped element is reachable
/// through `AsRef<Elem<M, E>>`.
pub struct One<M, E>(Elem<M, E>);
600 
601 impl<M> One<M, RR> {
602     // Returns RR = = R**2 (mod n) where R = 2**r is the smallest power of
603     // 2**LIMB_BITS such that R > m.
604     //
605     // Even though the assembly on some 32-bit platforms works with 64-bit
606     // values, using `LIMB_BITS` here, rather than `N0_LIMBS_USED * LIMB_BITS`,
607     // is correct because R**2 will still be a multiple of the latter as
608     // `N0_LIMBS_USED` is either one or two.
newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self609     fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
610         let m_bits = m_bits.as_usize_bits();
611         let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
612 
613         // base = 2**(lg m - 1).
614         let bit = m_bits - 1;
615         let mut base = m.zero();
616         base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS);
617 
618         // Double `base` so that base == R == 2**r (mod m). For normal moduli
619         // that have the high bit of the highest limb set, this requires one
620         // doubling. Unusual moduli require more doublings but we are less
621         // concerned about the performance of those.
622         //
623         // Then double `base` again so that base == 2*R (mod n), i.e. `2` in
624         // Montgomery form (`elem_exp_vartime_()` requires the base to be in
625         // Montgomery form). Then compute
626         // RR = R**2 == base**r == R**r == (2**r)**r (mod n).
627         //
628         // Take advantage of the fact that `elem_mul_by_2` is faster than
629         // `elem_squared` by replacing some of the early squarings with shifts.
630         // TODO: Benchmark shift vs. squaring performance to determine the
631         // optimal value of `lg_base`.
632         let lg_base = 2usize; // Shifts vs. squaring trade-off.
633         debug_assert_eq!(lg_base.count_ones(), 1); // Must 2**n for n >= 0.
634         let shifts = r - bit + lg_base;
635         let exponent = (r / lg_base) as u64;
636         for _ in 0..shifts {
637             elem_mul_by_2(&mut base, m)
638         }
639         let RR = elem_exp_vartime_(base, exponent, m);
640 
641         Self(Elem {
642             limbs: RR.limbs,
643             encoding: PhantomData, // PhantomData<RR>
644         })
645     }
646 }
647 
648 impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
as_ref(&self) -> &Elem<M, E>649     fn as_ref(&self) -> &Elem<M, E> {
650         &self.0
651     }
652 }
653 
/// A non-secret odd positive value in the range
/// [3, PUBLIC_EXPONENT_MAX_VALUE].
///
/// The range and oddness are enforced by `PublicExponent::from_be_bytes`,
/// the only constructor in the visible part of this file.
#[derive(Clone, Copy, Debug)]
pub struct PublicExponent(u64);
658 
659 impl PublicExponent {
from_be_bytes( input: untrusted::Input, min_value: u64, ) -> Result<Self, error::KeyRejected>660     pub fn from_be_bytes(
661         input: untrusted::Input,
662         min_value: u64,
663     ) -> Result<Self, error::KeyRejected> {
664         if input.len() > 5 {
665             return Err(error::KeyRejected::too_large());
666         }
667         let value = input.read_all(error::KeyRejected::invalid_encoding(), |input| {
668             // The exponent can't be zero and it can't be prefixed with
669             // zero-valued bytes.
670             if input.peek(0) {
671                 return Err(error::KeyRejected::invalid_encoding());
672             }
673             let mut value = 0u64;
674             loop {
675                 let byte = input
676                     .read_byte()
677                     .map_err(|untrusted::EndOfInput| error::KeyRejected::invalid_encoding())?;
678                 value = (value << 8) | u64::from(byte);
679                 if input.at_end() {
680                     return Ok(value);
681                 }
682             }
683         })?;
684 
685         // Step 2 / Step b. NIST SP800-89 defers to FIPS 186-3, which requires
686         // `e >= 65537`. We enforce this when signing, but are more flexible in
687         // verification, for compatibility. Only small public exponents are
688         // supported.
689         if value & 1 != 1 {
690             return Err(error::KeyRejected::invalid_component());
691         }
692         debug_assert!(min_value & 1 == 1);
693         debug_assert!(min_value <= PUBLIC_EXPONENT_MAX_VALUE);
694         if min_value < 3 {
695             return Err(error::KeyRejected::invalid_component());
696         }
697         if value < min_value {
698             return Err(error::KeyRejected::too_small());
699         }
700         if value > PUBLIC_EXPONENT_MAX_VALUE {
701             return Err(error::KeyRejected::too_large());
702         }
703 
704         Ok(Self(value))
705     }
706 }
707 
// The largest accepted public exponent: 2**33 - 1, i.e. the largest 33-bit
// value.
//
// This limit was chosen to bound the performance of the simple
// exponentiation-by-squaring implementation in `elem_exp_vartime`. In
// particular, it helps mitigate theoretical resource exhaustion attacks. 33
// bits was chosen as the limit based on the recommendations in [1] and
// [2]. Windows CryptoAPI (at least older versions) doesn't support values
// larger than 32 bits [3], so it is unlikely that exponents larger than 32
// bits are being used for anything Windows commonly does.
//
// [1] https://www.imperialviolet.org/2012/03/16/rsae.html
// [2] https://www.imperialviolet.org/2012/03/17/rsados.html
// [3] https://msdn.microsoft.com/en-us/library/aa387685(VS.85).aspx
const PUBLIC_EXPONENT_MAX_VALUE: u64 = (1u64 << 33) - 1;
720 
721 /// Calculates base**exponent (mod m).
722 // TODO: The test coverage needs to be expanded, e.g. test with the largest
723 // accepted exponent and with the most common values of 65537 and 3.
elem_exp_vartime<M>( base: Elem<M, Unencoded>, PublicExponent(exponent): PublicExponent, m: &Modulus<M>, ) -> Elem<M, R>724 pub fn elem_exp_vartime<M>(
725     base: Elem<M, Unencoded>,
726     PublicExponent(exponent): PublicExponent,
727     m: &Modulus<M>,
728 ) -> Elem<M, R> {
729     let base = elem_mul(m.oneRR().as_ref(), base, &m);
730     elem_exp_vartime_(base, exponent, &m.as_partial())
731 }
732 
733 /// Calculates base**exponent (mod m).
elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R>734 fn elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R> {
735     // Use what [Knuth] calls the "S-and-X binary method", i.e. variable-time
736     // square-and-multiply that scans the exponent from the most significant
737     // bit to the least significant bit (left-to-right). Left-to-right requires
738     // less storage compared to right-to-left scanning, at the cost of needing
739     // to compute `exponent.leading_zeros()`, which we assume to be cheap.
740     //
741     // During RSA public key operations the exponent is almost always either 65537
742     // (0b10000000000000001) or 3 (0b11), both of which have a Hamming weight
743     // of 2. During Montgomery setup the exponent is almost always a power of two,
744     // with Hamming weight 1. As explained in [Knuth], exponentiation by squaring
745     // is the most efficient algorithm when the Hamming weight is 2 or less. It
746     // isn't the most efficient for all other, uncommon, exponent values but any
747     // suboptimality is bounded by `PUBLIC_EXPONENT_MAX_VALUE`.
748     //
749     // This implementation is slightly simplified by taking advantage of the
750     // fact that we require the exponent to be a positive integer.
751     //
752     // [Knuth]: The Art of Computer Programming, Volume 2: Seminumerical
753     //          Algorithms (3rd Edition), Section 4.6.3.
754     assert!(exponent >= 1);
755     assert!(exponent <= PUBLIC_EXPONENT_MAX_VALUE);
756     let mut acc = base.clone();
757     let mut bit = 1 << (64 - 1 - exponent.leading_zeros());
758     debug_assert!((exponent & bit) != 0);
759     while bit > 1 {
760         bit >>= 1;
761         acc = elem_squared(acc, m);
762         if (exponent & bit) != 0 {
763             acc = elem_mul_(&base, acc, m);
764         }
765     }
766     acc
767 }
768 
/// A private exponent associated with the modulus type `M`.
///
/// `M` represents the prime modulus for which the exponent is in the interval
/// [1, `m` - 1).
pub struct PrivateExponent<M> {
    /// The exponent's value, stored with the same width as the modulus.
    limbs: BoxedLimbs<M>,
}
774 
775 impl<M> PrivateExponent<M> {
from_be_bytes_padded( input: untrusted::Input, p: &Modulus<M>, ) -> Result<Self, error::Unspecified>776     pub fn from_be_bytes_padded(
777         input: untrusted::Input,
778         p: &Modulus<M>,
779     ) -> Result<Self, error::Unspecified> {
780         let dP = BoxedLimbs::from_be_bytes_padded_less_than(input, p)?;
781 
782         // Proof that `dP < p - 1`:
783         //
784         // If `dP < p` then either `dP == p - 1` or `dP < p - 1`. Since `p` is
785         // odd, `p - 1` is even. `d` is odd, and an odd number modulo an even
786         // number is odd. Therefore `dP` must be odd. But then it cannot be
787         // `p - 1` and so we know `dP < p - 1`.
788         //
789         // Further we know `dP != 0` because `dP` is not even.
790         if limb::limbs_are_even_constant_time(&dP) != LimbMask::False {
791             return Err(error::Unspecified);
792         }
793 
794         Ok(Self { limbs: dP })
795     }
796 }
797 
798 impl<M: Prime> PrivateExponent<M> {
799     // Returns `p - 2`.
for_flt(p: &Modulus<M>) -> Self800     fn for_flt(p: &Modulus<M>) -> Self {
801         let two = elem_add(p.one(), p.one(), p);
802         let p_minus_2 = elem_sub(p.zero(), &two, p);
803         Self {
804             limbs: p_minus_2.limbs,
805         }
806     }
807 }
808 
/// Raises `base` to `exponent` modulo `m` using fixed 5-bit windows.
///
/// A table of `2**5` entries is filled with successive powers of `base`
/// (entry `i` is the Montgomery product of two earlier entries), and the
/// exponent is then consumed window by window via
/// `limb::fold_5_bit_windows`. Table lookups go through
/// `LIMBS_select_512_32`. The accumulated value is passed through
/// `into_unencoded` before being returned.
#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    // Shared view of the `i`th `num_limbs`-wide entry of `table`.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        let start = i * num_limbs;
        &table[start..][..num_limbs]
    }

    // Mutable view of the `i`th `num_limbs`-wide entry of `table`.
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        let start = i * num_limbs;
        &mut table[start..][..num_limbs]
    }

    // Loads table entry `i` into `r`; panics if the selection routine
    // reports failure.
    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
        extern "C" {
            fn LIMBS_select_512_32(
                r: *mut Limb,
                table: *const Limb,
                num_limbs: c::size_t,
                i: Window,
            ) -> bssl::Result;
        }
        let selected = unsafe {
            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
        };
        Result::from(selected).unwrap();
    }

    // One window step: square `acc` `WINDOW_BITS` times, then multiply by
    // table entry `i`. `tmp` is scratch space for the gathered entry and is
    // handed back so it can be reused by the next step.
    fn power<M>(
        table: &[Limb],
        i: Window,
        acc: Elem<M, R>,
        mut tmp: Elem<M, R>,
        m: &Modulus<M>,
    ) -> (Elem<M, R>, Elem<M, R>) {
        let mut squared = acc;
        for _ in 0..WINDOW_BITS {
            squared = elem_squared(squared, &m.as_partial());
        }
        gather(table, i, &mut tmp);
        (elem_mul(&tmp, squared, m), tmp)
    }

    let num_limbs = m.limbs.len();

    let mut table = vec![0; TABLE_ENTRIES * num_limbs];

    // Entry 0 is `m.one()` multiplied by `oneRR`; entry 1 is `base` itself.
    let tmp = elem_mul(m.oneRR().as_ref(), m.one(), m);
    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);

    // Every later entry is the product of two earlier ones: an even index is
    // the square of entry `i / 2`, an odd index is entry `i - 1` times entry 1.
    for i in 2..TABLE_ENTRIES {
        let (lhs, rhs) = match i % 2 {
            0 => (i / 2, i / 2),
            _ => (i - 1, 1),
        };
        let (filled, unfilled) = table.split_at_mut(num_limbs * i);
        let lhs = entry(filled, lhs, num_limbs);
        let rhs = entry(filled, rhs, num_limbs);
        let target = entry_mut(unfilled, 0, num_limbs);
        limbs_mont_product(target, lhs, rhs, &m.limbs, &m.n0);
    }

    let (acc, _) = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            // `base`'s storage is reused for the accumulator; its contents
            // are overwritten by the gathered entry.
            let mut acc = Elem {
                limbs: base.limbs,
                encoding: PhantomData,
            };
            gather(&table, initial_window, &mut acc);
            (acc, tmp)
        },
        |(acc, tmp), window| power(&table, window, acc, tmp, m),
    );

    Ok(acc.into_unencoded(m))
}
896 
897 /// Uses Fermat's Little Theorem to calculate modular inverse in constant time.
elem_inverse_consttime<M: Prime>( a: Elem<M, R>, m: &Modulus<M>, ) -> Result<Elem<M, Unencoded>, error::Unspecified>898 pub fn elem_inverse_consttime<M: Prime>(
899     a: Elem<M, R>,
900     m: &Modulus<M>,
901 ) -> Result<Elem<M, Unencoded>, error::Unspecified> {
902     elem_exp_consttime(a, &PrivateExponent::for_flt(&m), m)
903 }
904 
905 #[cfg(target_arch = "x86_64")]
elem_exp_consttime<M>( base: Elem<M, R>, exponent: &PrivateExponent<M>, m: &Modulus<M>, ) -> Result<Elem<M, Unencoded>, error::Unspecified>906 pub fn elem_exp_consttime<M>(
907     base: Elem<M, R>,
908     exponent: &PrivateExponent<M>,
909     m: &Modulus<M>,
910 ) -> Result<Elem<M, Unencoded>, error::Unspecified> {
911     // The x86_64 assembly was written under the assumption that the input data
912     // is aligned to `MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH` bytes, which was/is
913     // 64 in OpenSSL. Similarly, OpenSSL uses the x86_64 assembly functions by
914     // giving it only inputs `tmp`, `am`, and `np` that immediately follow the
915     // table. The code seems to "work" even when the inputs aren't exactly
916     // like that but the side channel defenses might not be as effective. All
917     // the awkwardness here stems from trying to use the assembly code like
918     // OpenSSL does.
919 
920     use crate::limb::Window;
921 
922     const WINDOW_BITS: usize = 5;
923     const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
924 
925     let num_limbs = m.limbs.len();
926 
927     const ALIGNMENT: usize = 64;
928     assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
929     let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
930     let (table, state) = {
931         let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
932         let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
933         assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
934         table.split_at_mut(TABLE_ENTRIES * num_limbs)
935     };
936 
937     fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
938         &table[(i * num_limbs)..][..num_limbs]
939     }
940     fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
941         &mut table[(i * num_limbs)..][..num_limbs]
942     }
943 
944     const ACC: usize = 0; // `tmp` in OpenSSL
945     const BASE: usize = ACC + 1; // `am` in OpenSSL
946     const M: usize = BASE + 1; // `np` in OpenSSL
947 
948     entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
949     entry_mut(state, M, num_limbs).copy_from_slice(&m.limbs);
950 
951     fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
952         extern "C" {
953             fn GFp_bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
954         }
955         unsafe {
956             GFp_bn_scatter5(
957                 entry(state, ACC, num_limbs).as_ptr(),
958                 num_limbs,
959                 table.as_mut_ptr(),
960                 i,
961             )
962         }
963     }
964 
965     fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
966         extern "C" {
967             fn GFp_bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
968         }
969         unsafe {
970             GFp_bn_gather5(
971                 entry_mut(state, ACC, num_limbs).as_mut_ptr(),
972                 num_limbs,
973                 table.as_ptr(),
974                 i,
975             )
976         }
977     }
978 
979     fn gather_square(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
980         gather(table, state, i, num_limbs);
981         assert_eq!(ACC, 0);
982         let (acc, rest) = state.split_at_mut(num_limbs);
983         let m = entry(rest, M - 1, num_limbs);
984         limbs_mont_square(acc, m, n0);
985     }
986 
987     fn gather_mul_base(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
988         extern "C" {
989             fn GFp_bn_mul_mont_gather5(
990                 rp: *mut Limb,
991                 ap: *const Limb,
992                 table: *const Limb,
993                 np: *const Limb,
994                 n0: &N0,
995                 num: c::size_t,
996                 power: Window,
997             );
998         }
999         unsafe {
1000             GFp_bn_mul_mont_gather5(
1001                 entry_mut(state, ACC, num_limbs).as_mut_ptr(),
1002                 entry(state, BASE, num_limbs).as_ptr(),
1003                 table.as_ptr(),
1004                 entry(state, M, num_limbs).as_ptr(),
1005                 n0,
1006                 num_limbs,
1007                 i,
1008             );
1009         }
1010     }
1011 
1012     fn power(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
1013         extern "C" {
1014             fn GFp_bn_power5(
1015                 r: *mut Limb,
1016                 a: *const Limb,
1017                 table: *const Limb,
1018                 n: *const Limb,
1019                 n0: &N0,
1020                 num: c::size_t,
1021                 i: Window,
1022             );
1023         }
1024         unsafe {
1025             GFp_bn_power5(
1026                 entry_mut(state, ACC, num_limbs).as_mut_ptr(),
1027                 entry_mut(state, ACC, num_limbs).as_mut_ptr(),
1028                 table.as_ptr(),
1029                 entry(state, M, num_limbs).as_ptr(),
1030                 n0,
1031                 num_limbs,
1032                 i,
1033             );
1034         }
1035     }
1036 
1037     // table[0] = base**0.
1038     {
1039         let acc = entry_mut(state, ACC, num_limbs);
1040         acc[0] = 1;
1041         limbs_mont_mul(acc, &m.oneRR.0.limbs, &m.limbs, &m.n0);
1042     }
1043     scatter(table, state, 0, num_limbs);
1044 
1045     // table[1] = base**1.
1046     entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
1047     scatter(table, state, 1, num_limbs);
1048 
1049     for i in 2..(TABLE_ENTRIES as Window) {
1050         if i % 2 == 0 {
1051             // TODO: Optimize this to avoid gathering
1052             gather_square(table, state, &m.n0, i / 2, num_limbs);
1053         } else {
1054             gather_mul_base(table, state, &m.n0, i - 1, num_limbs)
1055         };
1056         scatter(table, state, i, num_limbs);
1057     }
1058 
1059     let state = limb::fold_5_bit_windows(
1060         &exponent.limbs,
1061         |initial_window| {
1062             gather(table, state, initial_window, num_limbs);
1063             state
1064         },
1065         |state, window| {
1066             power(table, state, &m.n0, window, num_limbs);
1067             state
1068         },
1069     );
1070 
1071     extern "C" {
1072         fn GFp_bn_from_montgomery(
1073             r: *mut Limb,
1074             a: *const Limb,
1075             not_used: *const Limb,
1076             n: *const Limb,
1077             n0: &N0,
1078             num: c::size_t,
1079         ) -> bssl::Result;
1080     }
1081     Result::from(unsafe {
1082         GFp_bn_from_montgomery(
1083             entry_mut(state, ACC, num_limbs).as_mut_ptr(),
1084             entry(state, ACC, num_limbs).as_ptr(),
1085             core::ptr::null(),
1086             entry(state, M, num_limbs).as_ptr(),
1087             &m.n0,
1088             num_limbs,
1089         )
1090     })?;
1091     let mut r = Elem {
1092         limbs: base.limbs,
1093         encoding: PhantomData,
1094     };
1095     r.limbs.copy_from_slice(entry(state, ACC, num_limbs));
1096     Ok(r)
1097 }
1098 
1099 /// Verified a == b**-1 (mod m), i.e. a**-1 == b (mod m).
verify_inverses_consttime<M>( a: &Elem<M, R>, b: Elem<M, Unencoded>, m: &Modulus<M>, ) -> Result<(), error::Unspecified>1100 pub fn verify_inverses_consttime<M>(
1101     a: &Elem<M, R>,
1102     b: Elem<M, Unencoded>,
1103     m: &Modulus<M>,
1104 ) -> Result<(), error::Unspecified> {
1105     if elem_mul(a, b, m).is_one() {
1106         Ok(())
1107     } else {
1108         Err(error::Unspecified)
1109     }
1110 }
1111 
1112 #[inline]
elem_verify_equal_consttime<M, E>( a: &Elem<M, E>, b: &Elem<M, E>, ) -> Result<(), error::Unspecified>1113 pub fn elem_verify_equal_consttime<M, E>(
1114     a: &Elem<M, E>,
1115     b: &Elem<M, E>,
1116 ) -> Result<(), error::Unspecified> {
1117     if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
1118         Ok(())
1119     } else {
1120         Err(error::Unspecified)
1121     }
1122 }
1123 
1124 /// Nonnegative integers.
1125 pub struct Nonnegative {
1126     limbs: Vec<Limb>,
1127 }
1128 
1129 impl Nonnegative {
from_be_bytes_with_bit_length( input: untrusted::Input, ) -> Result<(Self, bits::BitLength), error::Unspecified>1130     pub fn from_be_bytes_with_bit_length(
1131         input: untrusted::Input,
1132     ) -> Result<(Self, bits::BitLength), error::Unspecified> {
1133         let mut limbs = vec![0; (input.len() + LIMB_BYTES - 1) / LIMB_BYTES];
1134         // Rejects empty inputs.
1135         limb::parse_big_endian_and_pad_consttime(input, &mut limbs)?;
1136         while limbs.last() == Some(&0) {
1137             let _ = limbs.pop();
1138         }
1139         let r_bits = limb::limbs_minimal_bits(&limbs);
1140         Ok((Self { limbs }, r_bits))
1141     }
1142 
1143     #[inline]
is_odd(&self) -> bool1144     pub fn is_odd(&self) -> bool {
1145         limb::limbs_are_even_constant_time(&self.limbs) != LimbMask::True
1146     }
1147 
verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified>1148     pub fn verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified> {
1149         if !greater_than(other, self) {
1150             return Err(error::Unspecified);
1151         }
1152         Ok(())
1153     }
1154 
to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified>1155     pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
1156         self.verify_less_than_modulus(&m)?;
1157         let mut r = m.zero();
1158         r.limbs[0..self.limbs.len()].copy_from_slice(&self.limbs);
1159         Ok(r)
1160     }
1161 
verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified>1162     pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
1163         if self.limbs.len() > m.limbs.len() {
1164             return Err(error::Unspecified);
1165         }
1166         if self.limbs.len() == m.limbs.len() {
1167             if limb::limbs_less_than_limbs_consttime(&self.limbs, &m.limbs) != LimbMask::True {
1168                 return Err(error::Unspecified);
1169             }
1170         }
1171         Ok(())
1172     }
1173 }
1174 
1175 // Returns a > b.
greater_than(a: &Nonnegative, b: &Nonnegative) -> bool1176 fn greater_than(a: &Nonnegative, b: &Nonnegative) -> bool {
1177     if a.limbs.len() == b.limbs.len() {
1178         limb::limbs_less_than_limbs_vartime(&b.limbs, &a.limbs)
1179     } else {
1180         a.limbs.len() > b.limbs.len()
1181     }
1182 }
1183 
1184 #[derive(Clone)]
1185 #[repr(transparent)]
1186 struct N0([Limb; 2]);
1187 
1188 const N0_LIMBS_USED: usize = 64 / LIMB_BITS;
1189 
1190 impl From<u64> for N0 {
1191     #[inline]
from(n0: u64) -> Self1192     fn from(n0: u64) -> Self {
1193         #[cfg(target_pointer_width = "64")]
1194         {
1195             Self([n0, 0])
1196         }
1197 
1198         #[cfg(target_pointer_width = "32")]
1199         {
1200             Self([n0 as Limb, (n0 >> LIMB_BITS) as Limb])
1201         }
1202     }
1203 }
1204 
1205 /// r *= a
limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0)1206 fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0) {
1207     debug_assert_eq!(r.len(), m.len());
1208     debug_assert_eq!(a.len(), m.len());
1209 
1210     #[cfg(any(
1211         target_arch = "aarch64",
1212         target_arch = "arm",
1213         target_arch = "x86_64",
1214         target_arch = "x86"
1215     ))]
1216     unsafe {
1217         GFp_bn_mul_mont(
1218             r.as_mut_ptr(),
1219             r.as_ptr(),
1220             a.as_ptr(),
1221             m.as_ptr(),
1222             n0,
1223             r.len(),
1224         )
1225     }
1226 
1227     #[cfg(not(any(
1228         target_arch = "aarch64",
1229         target_arch = "arm",
1230         target_arch = "x86_64",
1231         target_arch = "x86"
1232     )))]
1233     {
1234         let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
1235         let tmp = &mut tmp[..(2 * a.len())];
1236         limbs_mul(tmp, r, a);
1237         limbs_from_mont_in_place(r, tmp, m, n0);
1238     }
1239 }
1240 
limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0)1241 fn limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0) {
1242     extern "C" {
1243         fn GFp_bn_from_montgomery_in_place(
1244             r: *mut Limb,
1245             num_r: c::size_t,
1246             a: *mut Limb,
1247             num_a: c::size_t,
1248             n: *const Limb,
1249             num_n: c::size_t,
1250             n0: &N0,
1251         ) -> bssl::Result;
1252     }
1253     Result::from(unsafe {
1254         GFp_bn_from_montgomery_in_place(
1255             r.as_mut_ptr(),
1256             r.len(),
1257             tmp.as_mut_ptr(),
1258             tmp.len(),
1259             m.as_ptr(),
1260             m.len(),
1261             &n0,
1262         )
1263     })
1264     .unwrap()
1265 }
1266 
// r = a * b, with `r` twice as long as `a` and `b` (lengths are checked in
// debug builds only).
#[cfg(not(any(
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "x86_64",
    target_arch = "x86"
)))]
fn limbs_mul(r: &mut [Limb], a: &[Limb], b: &[Limb]) {
    debug_assert_eq!(r.len(), 2 * a.len());
    debug_assert_eq!(a.len(), b.len());
    let num_limbs = a.len();

    // Only the low half needs clearing: each pass below stores into the next
    // limb of the high half before that limb is ever read.
    crate::polyfill::slice::fill(&mut r[..num_limbs], 0);
    for (offset, &multiplier) in b.iter().enumerate() {
        // Accumulate `a * multiplier` into the `num_limbs`-wide window of `r`
        // starting at `offset`; the returned limb goes just above the window.
        let window = &mut r[offset..][..num_limbs];
        let carry = unsafe {
            GFp_limbs_mul_add_limb(window.as_mut_ptr(), a.as_ptr(), multiplier, num_limbs)
        };
        r[num_limbs + offset] = carry;
    }
}
1290 
/// r = a * b
///
/// `r`, `a`, `b`, and `m` are expected to have the same length (checked in
/// debug builds only).
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(r: &mut [Limb], a: &[Limb], b: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    debug_assert_eq!(b.len(), m.len());

    // Targets that provide `GFp_bn_mul_mont`.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    {
        let num_limbs = r.len();
        unsafe {
            GFp_bn_mul_mont(
                r.as_mut_ptr(),
                a.as_ptr(),
                b.as_ptr(),
                m.as_ptr(),
                n0,
                num_limbs,
            )
        }
    }

    // Other targets: form the double-width product in scratch space, then
    // hand it to `limbs_from_mont_in_place`.
    #[cfg(not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    )))]
    {
        let mut scratch = [0; 2 * MODULUS_MAX_LIMBS];
        let product = &mut scratch[..(2 * a.len())];
        limbs_mul(product, a, b);
        limbs_from_mont_in_place(r, product, m, n0)
    }
}
1328 
1329 /// r = r**2
limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0)1330 fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0) {
1331     debug_assert_eq!(r.len(), m.len());
1332     #[cfg(any(
1333         target_arch = "aarch64",
1334         target_arch = "arm",
1335         target_arch = "x86_64",
1336         target_arch = "x86"
1337     ))]
1338     unsafe {
1339         GFp_bn_mul_mont(
1340             r.as_mut_ptr(),
1341             r.as_ptr(),
1342             r.as_ptr(),
1343             m.as_ptr(),
1344             n0,
1345             r.len(),
1346         )
1347     }
1348 
1349     #[cfg(not(any(
1350         target_arch = "aarch64",
1351         target_arch = "arm",
1352         target_arch = "x86_64",
1353         target_arch = "x86"
1354     )))]
1355     {
1356         let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
1357         let tmp = &mut tmp[..(2 * r.len())];
1358         limbs_mul(tmp, r, r);
1359         limbs_from_mont_in_place(r, tmp, m, n0)
1360     }
1361 }
1362 
1363 extern "C" {
1364     #[cfg(any(
1365         target_arch = "aarch64",
1366         target_arch = "arm",
1367         target_arch = "x86_64",
1368         target_arch = "x86"
1369     ))]
1370     // `r` and/or 'a' and/or 'b' may alias.
GFp_bn_mul_mont( r: *mut Limb, a: *const Limb, b: *const Limb, n: *const Limb, n0: &N0, num_limbs: c::size_t, )1371     fn GFp_bn_mul_mont(
1372         r: *mut Limb,
1373         a: *const Limb,
1374         b: *const Limb,
1375         n: *const Limb,
1376         n0: &N0,
1377         num_limbs: c::size_t,
1378     );
1379 
1380     // `r` must not alias `a`
1381     #[cfg(any(
1382         test,
1383         not(any(
1384             target_arch = "aarch64",
1385             target_arch = "arm",
1386             target_arch = "x86_64",
1387             target_arch = "x86"
1388         ))
1389     ))]
1390     #[must_use]
GFp_limbs_mul_add_limb(r: *mut Limb, a: *const Limb, b: Limb, num_limbs: c::size_t) -> Limb1391     fn GFp_limbs_mul_add_limb(r: *mut Limb, a: *const Limb, b: Limb, num_limbs: c::size_t) -> Limb;
1392 }
1393 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use alloc::format;

    // Type-level representation of an arbitrary modulus.
    struct M {}

    unsafe impl PublicModulus for M {}

    // Checks `elem_exp_consttime` against the vectors in the test file.
    #[test]
    fn test_elem_exp_consttime() {
        test::run(
            test_file!("bigint_elem_exp_consttime_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let modulus = consume_modulus::<M>(test_case, "M");
                let want = consume_elem(test_case, "ModExp", &modulus);
                let base = consume_elem(test_case, "A", &modulus);
                let exponent_bytes = test_case.consume_bytes("E");
                let exponent = PrivateExponent::from_be_bytes_padded(
                    untrusted::Input::from(&exponent_bytes),
                    &modulus,
                )
                .expect("valid exponent");

                let base = into_encoded(base, &modulus);
                let got = elem_exp_consttime(base, &exponent, &modulus).unwrap();
                assert_elem_eq(&got, &want);

                Ok(())
            },
        )
    }

    // TODO: fn test_elem_exp_vartime() using
    // "src/rsa/bigint_elem_exp_vartime_tests.txt". See that file for details.
    // In the meantime, the function is tested indirectly via the RSA
    // verification and signing tests.

    // Checks `elem_mul` against the vectors in the test file.
    #[test]
    fn test_elem_mul() {
        test::run(
            test_file!("bigint_elem_mul_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let modulus = consume_modulus::<M>(test_case, "M");
                let want = consume_elem(test_case, "ModMul", &modulus);
                let a = consume_elem(test_case, "A", &modulus);
                let b = consume_elem(test_case, "B", &modulus);

                let b = into_encoded(b, &modulus);
                let a = into_encoded(a, &modulus);
                let got = elem_mul(&a, b, &modulus).into_unencoded(&modulus);
                assert_elem_eq(&got, &want);

                Ok(())
            },
        )
    }

    // Checks `elem_squared` against the vectors in the test file.
    #[test]
    fn test_elem_squared() {
        test::run(
            test_file!("bigint_elem_squared_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let modulus = consume_modulus::<M>(test_case, "M");
                let want = consume_elem(test_case, "ModSquare", &modulus);
                let a = consume_elem(test_case, "A", &modulus);

                let a = into_encoded(a, &modulus);
                let got = elem_squared(a, &modulus.as_partial()).into_unencoded(&modulus);
                assert_elem_eq(&got, &want);

                Ok(())
            },
        )
    }

    // Checks `elem_reduced` against the vectors in the test file.
    #[test]
    fn test_elem_reduced() {
        test::run(
            test_file!("bigint_elem_reduced_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct MM {}
                unsafe impl SmallerModulus<MM> for M {}
                unsafe impl NotMuchSmallerModulus<MM> for M {}

                let modulus = consume_modulus::<M>(test_case, "M");
                let want = consume_elem(test_case, "R", &modulus);
                // The input is twice as wide as an element of `modulus`.
                let wide_limbs = want.limbs.len() * 2;
                let a = consume_elem_unchecked::<MM>(test_case, "A", wide_limbs);

                let reduced = elem_reduced(&a, &modulus);
                let oneRR = modulus.oneRR();
                let got = elem_mul(oneRR.as_ref(), reduced, &modulus);
                assert_elem_eq(&got, &want);

                Ok(())
            },
        )
    }

    // Checks `elem_reduced_once` against the vectors in the test file.
    #[test]
    fn test_elem_reduced_once() {
        test::run(
            test_file!("bigint_elem_reduced_once_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct N {}
                struct QQ {}
                unsafe impl SmallerModulus<N> for QQ {}
                unsafe impl SlightlySmallerModulus<N> for QQ {}

                let qq = consume_modulus::<QQ>(test_case, "QQ");
                let want = consume_elem::<QQ>(test_case, "R", &qq);
                let n = consume_modulus::<N>(test_case, "N");
                let a = consume_elem::<N>(test_case, "A", &n);

                let got = elem_reduced_once(&a, &qq);
                assert_elem_eq(&got, &want);

                Ok(())
            },
        )
    }

    // The `Debug` output of a modulus is just the type name.
    #[test]
    fn test_modulus_debug() {
        let bytes = [0xff; LIMB_BYTES * MODULUS_MIN_LIMBS];
        let (modulus, _) =
            Modulus::<M>::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        assert_eq!("Modulus", format!("{:?}", modulus));
    }

    // The `Debug` output of a public exponent shows its decimal value.
    #[test]
    fn test_public_exponent_debug() {
        let encoded = untrusted::Input::from(&[0x1, 0x00, 0x01]);
        let exponent = PublicExponent::from_be_bytes(encoded, 65537).unwrap();
        assert_eq!("PublicExponent(65537)", format!("{:?}", exponent));
    }

    // Reads the field `name` as an element of `m`; panics if
    // `Elem::from_be_bytes_padded` rejects it.
    fn consume_elem<M>(
        test_case: &mut test::TestCase,
        name: &str,
        m: &Modulus<M>,
    ) -> Elem<M, Unencoded> {
        let bytes = test_case.consume_bytes(name);
        Elem::from_be_bytes_padded(untrusted::Input::from(&bytes), m).unwrap()
    }

    // Reads the field `name` into a `num_limbs`-wide element without
    // comparing it against any modulus.
    fn consume_elem_unchecked<M>(
        test_case: &mut test::TestCase,
        name: &str,
        num_limbs: usize,
    ) -> Elem<M, Unencoded> {
        let parsed = consume_nonnegative(test_case, name);
        let width = Width {
            num_limbs,
            m: PhantomData,
        };
        let mut limbs = BoxedLimbs::zero(width);
        limbs[0..parsed.limbs.len()].copy_from_slice(&parsed.limbs);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }

    // Reads the field `name` as a modulus; the bit length is discarded.
    fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> Modulus<M> {
        let bytes = test_case.consume_bytes(name);
        let (modulus, _) =
            Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        modulus
    }

    // Reads the field `name` as a `Nonnegative`; the bit length is discarded.
    fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
        let bytes = test_case.consume_bytes(name);
        let (value, _bit_length) =
            Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        value
    }

    // Panics, printing both limb arrays in hex, when `a` and `b` differ.
    fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
        let comparison = elem_verify_equal_consttime(a, b);
        if comparison.is_err() {
            panic!("{:x?} != {:x?}", &*a.limbs, &*b.limbs);
        }
    }

    // Multiplies `a` by `oneRR` to obtain an `R`-encoded element.
    fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
        let oneRR = m.oneRR();
        elem_mul(oneRR.as_ref(), a, m)
    }

    // Table-driven check of `GFp_limbs_mul_add_limb`. Each case is
    // (initial r, a, w, expected return value, expected final r).
    #[test]
    // TODO: wasm
    fn test_mul_add_words() {
        const ZERO: Limb = 0;
        const MAX: Limb = ZERO.wrapping_sub(1);
        static TEST_CASES: &[(&[Limb], &[Limb], Limb, Limb, &[Limb])] = &[
            (&[0], &[0], 0, 0, &[0]),
            (&[MAX], &[0], MAX, 0, &[MAX]),
            (&[0], &[MAX], MAX, MAX - 1, &[1]),
            (&[MAX], &[MAX], MAX, MAX, &[0]),
            (&[0, 0], &[MAX, MAX], MAX, MAX - 1, &[1, MAX]),
            (&[1, 0], &[MAX, MAX], MAX, MAX - 1, &[2, MAX]),
            (&[MAX, 0], &[MAX, MAX], MAX, MAX, &[0, 0]),
            (&[0, 1], &[MAX, MAX], MAX, MAX, &[1, 0]),
            (&[MAX, MAX], &[MAX, MAX], MAX, MAX, &[0, MAX]),
        ];

        for (i, (r_input, a, w, expected_retval, expected_r)) in TEST_CASES.iter().enumerate() {
            extern crate std;
            let mut r = std::vec::Vec::from(*r_input);
            assert_eq!(r.len(), a.len()); // Sanity check
            let num_limbs = a.len();
            let actual_retval =
                unsafe { GFp_limbs_mul_add_limb(r.as_mut_ptr(), a.as_ptr(), *w, num_limbs) };
            assert_eq!(&r, expected_r, "{}: {:x?} != {:x?}", i, &r[..], expected_r);
            assert_eq!(
                actual_retval, *expected_retval,
                "{}: {:x?} != {:x?}",
                i, actual_retval, *expected_retval
            );
        }
    }
}
1627