1 use core;
2 use core::mem;
3 use traits::checked_pow;
4 use traits::PrimInt;
5 use Integer;
6 
7 /// Provides methods to compute an integer's square root, cube root,
8 /// and arbitrary `n`th root.
9 pub trait Roots: Integer {
10     /// Returns the truncated principal `n`th root of an integer
11     /// -- `if x >= 0 { ⌊ⁿ√x⌋ } else { ⌈ⁿ√x⌉ }`
12     ///
13     /// This is solving for `r` in `rⁿ = x`, rounding toward zero.
14     /// If `x` is positive, the result will satisfy `rⁿ ≤ x < (r+1)ⁿ`.
15     /// If `x` is negative and `n` is odd, then `(r-1)ⁿ < x ≤ rⁿ`.
16     ///
17     /// # Panics
18     ///
19     /// Panics if `n` is zero:
20     ///
21     /// ```should_panic
22     /// # use num_integer::Roots;
23     /// println!("can't compute ⁰√x : {}", 123.nth_root(0));
24     /// ```
25     ///
26     /// or if `n` is even and `self` is negative:
27     ///
28     /// ```should_panic
29     /// # use num_integer::Roots;
30     /// println!("no imaginary numbers... {}", (-1).nth_root(10));
31     /// ```
32     ///
33     /// # Examples
34     ///
35     /// ```
36     /// use num_integer::Roots;
37     ///
38     /// let x: i32 = 12345;
39     /// assert_eq!(x.nth_root(1), x);
40     /// assert_eq!(x.nth_root(2), x.sqrt());
41     /// assert_eq!(x.nth_root(3), x.cbrt());
42     /// assert_eq!(x.nth_root(4), 10);
43     /// assert_eq!(x.nth_root(13), 2);
44     /// assert_eq!(x.nth_root(14), 1);
45     /// assert_eq!(x.nth_root(std::u32::MAX), 1);
46     ///
47     /// assert_eq!(std::i32::MAX.nth_root(30), 2);
48     /// assert_eq!(std::i32::MAX.nth_root(31), 1);
49     /// assert_eq!(std::i32::MIN.nth_root(31), -2);
50     /// assert_eq!((std::i32::MIN + 1).nth_root(31), -1);
51     ///
52     /// assert_eq!(std::u32::MAX.nth_root(31), 2);
53     /// assert_eq!(std::u32::MAX.nth_root(32), 1);
54     /// ```
nth_root(&self, n: u32) -> Self55     fn nth_root(&self, n: u32) -> Self;
56 
57     /// Returns the truncated principal square root of an integer -- `⌊√x⌋`
58     ///
59     /// This is solving for `r` in `r² = x`, rounding toward zero.
60     /// The result will satisfy `r² ≤ x < (r+1)²`.
61     ///
62     /// # Panics
63     ///
64     /// Panics if `self` is less than zero:
65     ///
66     /// ```should_panic
67     /// # use num_integer::Roots;
68     /// println!("no imaginary numbers... {}", (-1).sqrt());
69     /// ```
70     ///
71     /// # Examples
72     ///
73     /// ```
74     /// use num_integer::Roots;
75     ///
76     /// let x: i32 = 12345;
77     /// assert_eq!((x * x).sqrt(), x);
78     /// assert_eq!((x * x + 1).sqrt(), x);
79     /// assert_eq!((x * x - 1).sqrt(), x - 1);
80     /// ```
81     #[inline]
sqrt(&self) -> Self82     fn sqrt(&self) -> Self {
83         self.nth_root(2)
84     }
85 
86     /// Returns the truncated principal cube root of an integer --
87     /// `if x >= 0 { ⌊∛x⌋ } else { ⌈∛x⌉ }`
88     ///
89     /// This is solving for `r` in `r³ = x`, rounding toward zero.
90     /// If `x` is positive, the result will satisfy `r³ ≤ x < (r+1)³`.
91     /// If `x` is negative, then `(r-1)³ < x ≤ r³`.
92     ///
93     /// # Examples
94     ///
95     /// ```
96     /// use num_integer::Roots;
97     ///
98     /// let x: i32 = 1234;
99     /// assert_eq!((x * x * x).cbrt(), x);
100     /// assert_eq!((x * x * x + 1).cbrt(), x);
101     /// assert_eq!((x * x * x - 1).cbrt(), x - 1);
102     ///
103     /// assert_eq!((-(x * x * x)).cbrt(), -x);
104     /// assert_eq!((-(x * x * x + 1)).cbrt(), -x);
105     /// assert_eq!((-(x * x * x - 1)).cbrt(), -(x - 1));
106     /// ```
107     #[inline]
cbrt(&self) -> Self108     fn cbrt(&self) -> Self {
109         self.nth_root(3)
110     }
111 }
112 
113 /// Returns the truncated principal square root of an integer --
114 /// see [Roots::sqrt](trait.Roots.html#method.sqrt).
115 #[inline]
sqrt<T: Roots>(x: T) -> T116 pub fn sqrt<T: Roots>(x: T) -> T {
117     x.sqrt()
118 }
119 
120 /// Returns the truncated principal cube root of an integer --
121 /// see [Roots::cbrt](trait.Roots.html#method.cbrt).
122 #[inline]
cbrt<T: Roots>(x: T) -> T123 pub fn cbrt<T: Roots>(x: T) -> T {
124     x.cbrt()
125 }
126 
127 /// Returns the truncated principal `n`th root of an integer --
128 /// see [Roots::nth_root](trait.Roots.html#tymethod.nth_root).
129 #[inline]
nth_root<T: Roots>(x: T, n: u32) -> T130 pub fn nth_root<T: Roots>(x: T, n: u32) -> T {
131     x.nth_root(n)
132 }
133 
// Implements `Roots` for a signed primitive type `$T` by delegating to the
// unsigned type `$U` of the same width.
//
// A negative input is handled through its magnitude: `wrapping_neg` followed
// by a cast to `$U` yields |x| even for `$T::MIN`, whose magnitude does not
// fit in `$T`. The unsigned root is then cast back and negated.
macro_rules! signed_roots {
    ($T:ty, $U:ty) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                if *self >= 0 {
                    (*self as $U).nth_root(n) as Self
                } else {
                    assert!(n.is_odd(), "even roots of a negative are imaginary");
                    // For `n == 1` and `self == $T::MIN`, the root of the
                    // magnitude is |MIN| itself, which casts back to `MIN`.
                    // Plain `-` would overflow there (and panic when overflow
                    // checks are enabled). `wrapping_neg` maps `MIN` to itself,
                    // which is the correct answer. For every other input the
                    // root fits in `$T`, and `wrapping_neg` is ordinary negation.
                    ((self.wrapping_neg() as $U).nth_root(n) as Self).wrapping_neg()
                }
            }

            #[inline]
            fn sqrt(&self) -> Self {
                assert!(*self >= 0, "the square root of a negative is imaginary");
                (*self as $U).sqrt() as Self
            }

            #[inline]
            fn cbrt(&self) -> Self {
                if *self >= 0 {
                    (*self as $U).cbrt() as Self
                } else {
                    // The cube root of a magnitude of at most 2^(bits-1) is far
                    // below `$T::MAX`, so this negation cannot overflow.
                    -((self.wrapping_neg() as $U).cbrt() as Self)
                }
            }
        }
    };
}
164 
165 signed_roots!(i8, u8);
166 signed_roots!(i16, u16);
167 signed_roots!(i32, u32);
168 signed_roots!(i64, u64);
169 #[cfg(has_i128)]
170 signed_roots!(i128, u128);
171 signed_roots!(isize, usize);
172 
173 #[inline]
fixpoint<T, F>(mut x: T, f: F) -> T where T: Integer + Copy, F: Fn(T) -> T,174 fn fixpoint<T, F>(mut x: T, f: F) -> T
175 where
176     T: Integer + Copy,
177     F: Fn(T) -> T,
178 {
179     let mut xn = f(x);
180     while x < xn {
181         x = xn;
182         xn = f(x);
183     }
184     while x > xn {
185         x = xn;
186         xn = f(x);
187     }
188     x
189 }
190 
/// Width of `T` in bits, computed from its size in bytes.
#[inline]
fn bits<T>() -> u32 {
    let bytes = mem::size_of::<T>() as u32;
    bytes * 8
}
195 
196 #[inline]
log2<T: PrimInt>(x: T) -> u32197 fn log2<T: PrimInt>(x: T) -> u32 {
198     debug_assert!(x > T::zero());
199     bits::<T>() - 1 - x.leading_zeros()
200 }
201 
// Implements `Roots` for an unsigned primitive integer type `$T`.
//
// Each method wraps its algorithm in a nested `go` function that takes `$T`
// by value. When `$T` is wider than 64 bits, every method takes an early
// branch. That branch reduces the input recursively (dropping low bits)
// until it fits in a `u64`, then finishes with the `u64` implementation. As
// a result, the Newton-style iterations further down only run for types of
// at most 64 bits.
macro_rules! unsigned_roots {
    ($T:ident) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                // Truncated `n`th root: the largest `r` with `rⁿ <= a`.
                fn go(a: $T, n: u32) -> $T {
                    // Specialize small roots
                    match n {
                        0 => panic!("can't find a root of degree 0!"),
                        1 => return a,
                        2 => return a.sqrt(),
                        3 => return a.cbrt(),
                        _ => (),
                    }

                    // The root of values less than 2ⁿ can only be 0 or 1.
                    // `bits <= n` is tested first so that `1 << n` is never
                    // evaluated with a shift amount of at least the bit width.
                    // Past this point `a >= 2ⁿ`, so the root is at least 2.
                    if bits::<$T>() <= n || a < (1 << n) {
                        return (a > 0) as $T;
                    }

                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
                        // With `lo = 2·⌊ⁿ√(a >> n)⌋`, the exact root is either `lo` or
                        // `lo + 1`, so only `hi = lo + 1` has to be tested against `a`.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).nth_root(n) as $T
                        } else {
                            let lo = (a >> n).nth_root(n) << 1;
                            let hi = lo + 1;
                            // 128-bit `checked_mul` also involves division, but we can't always
                            // compute `hiⁿ` without risking overflow.  Try to avoid it though...
                            // `hi <= 2ᵏ`, where `k` is the exponent of `hi.next_power_of_two()`,
                            // so `hiⁿ <= 2^(k·n)`. The unchecked `pow` is used only when `k·n`
                            // is below the bit width.
                            if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
                                match checked_pow(hi, n as usize) {
                                    Some(x) if x <= a => hi,
                                    _ => lo,
                                }
                            } else {
                                if hi.pow(n) <= a {
                                    hi
                                } else {
                                    lo
                                }
                            }
                        };
                    }

                    // Starting estimate for the Newton iteration below.
                    // `1 << ((⌊log₂ x⌋ + n - 1) / n)` is a power-of-two estimate of the
                    // root; with `std`, larger inputs instead use `exp(ln(x) / n)` in `f64`.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        // for smaller inputs, `f64` doesn't justify its cost.
                        if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
                            1 << ((log2(x) + n - 1) / n)
                        } else {
                            ((x as f64).ln() / f64::from(n)).exp() as $T
                        }
                    }

                    // Without the `std` feature only the power-of-two estimate is available.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        1 << ((log2(x) + n - 1) / n)
                    }

                    // https://en.wikipedia.org/wiki/Nth_root_algorithm
                    // One integer Newton step: x' = (a / xⁿ⁻¹ + (n-1)·x) / n.
                    // If `xⁿ⁻¹` overflows `$T`, it certainly exceeds `a`, so the
                    // quotient `a / xⁿ⁻¹` is taken to be 0.
                    let n1 = n - 1;
                    let next = |x: $T| {
                        let y = match checked_pow(x, n1 as usize) {
                            Some(ax) => a / ax,
                            None => 0,
                        };
                        (y + x * n1 as $T) / n as $T
                    };
                    // `fixpoint` applies `next` while it increases the value, then
                    // while it decreases it, and returns where it stops.
                    fixpoint(guess(a, n), next)
                }
                go(*self, n)
            }

            #[inline]
            fn sqrt(&self) -> Self {
                // Truncated square root: the largest `r` with `r² <= a`.
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
                        // The root is `lo = 2·⌊√(a >> 2)⌋` or `lo + 1`.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).sqrt() as $T
                        } else {
                            let lo = (a >> 2u32).sqrt() << 1;
                            let hi = lo + 1;
                            if hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    // Roots of 0..=3 are 0 (for 0) or 1.
                    if a < 4 {
                        return (a > 0) as $T;
                    }

                    // Starting estimate: the `f64` square root when `std` is available.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).sqrt() as $T
                    }

                    // Without `std`: the power of two `1 << ((⌊log₂ x⌋ + 1) / 2)`.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 1) / 2)
                    }

                    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
                    // One step: x' = (a / x + x) / 2.
                    let next = |x: $T| (a / x + x) >> 1;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }

            #[inline]
            fn cbrt(&self) -> Self {
                // Truncated cube root: the largest `r` with `r³ <= a`.
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
                        // The root is `lo = 2·⌊∛(a >> 3)⌋` or `lo + 1`.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).cbrt() as $T
                        } else {
                            let lo = (a >> 3u32).cbrt() << 1;
                            let hi = lo + 1;
                            if hi * hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    if bits::<$T>() <= 32 {
                        // Implementation based on Hacker's Delight `icbrt2`
                        // Digit-by-digit extraction, consuming three bits of `a` per
                        // step from the top. Invariant: `y2 == y * y`. After doubling
                        // `y`, `b = 3·(y² + y) + 1 = (y+1)³ - y³`; if the remainder
                        // `x` can absorb `b` at scale `2ˢ`, the next root bit is 1.
                        let mut x = a;
                        let mut y2 = 0;
                        let mut y = 0;
                        let smax = bits::<$T>() / 3;
                        for s in (0..smax + 1).rev() {
                            let s = s * 3;
                            y2 *= 4;
                            y *= 2;
                            let b = 3 * (y2 + y) + 1;
                            // Equivalent to `x >= b << s`, but never shifts `b` out of range.
                            if x >> s >= b {
                                x -= b << s;
                                y2 += 2 * y + 1;
                                y += 1;
                            }
                        }
                        return y;
                    }

                    // Only types of 33..=64 bits reach this point.
                    if a < 8 {
                        return (a > 0) as $T;
                    }
                    // Values that fit in 32 bits reuse the `u32` implementation.
                    if a <= core::u32::MAX as $T {
                        return (a as u32).cbrt() as $T;
                    }

                    // Starting estimate: the `f64` cube root when `std` is available.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).cbrt() as $T
                    }

                    // Without `std`: the power of two `1 << ((⌊log₂ x⌋ + 2) / 3)`.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 2) / 3)
                    }

                    // https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
                    // One Newton step: x' = (a / x² + 2x) / 3.
                    let next = |x: $T| (a / (x * x) + x * 2) / 3;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }
        }
    };
}
384 
385 unsigned_roots!(u8);
386 unsigned_roots!(u16);
387 unsigned_roots!(u32);
388 unsigned_roots!(u64);
389 #[cfg(has_i128)]
390 unsigned_roots!(u128);
391 unsigned_roots!(usize);
392