/// Fused multiply-add: evaluates `(self * a) + b` with a single rounding
/// step, which is more accurate than performing the multiplication and the
/// addition as two separately rounded operations.
///
/// On targets that provide a dedicated `fma` instruction, `mul_add` may also
/// be faster than the unfused multiply-then-add sequence.
///
/// The operand types `A` and `B` default to `Self`, but implementors are free
/// to choose other types.
///
/// # Example
///
/// ```
/// use std::f32;
///
/// let m = 10.0_f32;
/// let x = 4.0_f32;
/// let b = 60.0_f32;
///
/// // 10.0 * 4.0 + 60.0 == 100.0
/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
///
/// assert!(abs_difference <= 100.0 * f32::EPSILON);
/// ```
pub trait MulAdd<A = Self, B = Self> {
    /// Type produced by the fused multiply-add.
    type Output;

    /// Performs the fused multiply-add, producing `(self * a) + b`.
    fn mul_add(self, a: A, b: B) -> Self::Output;
}
30 
/// In-place fused multiply-add: the assignment form of `MulAdd`, where the
/// result of `(*self * a) + b` is written back into `self`.
///
/// As with `MulAdd`, the operand types `A` and `B` default to `Self`.
pub trait MulAddAssign<A = Self, B = Self> {
    /// Performs the fused multiply-add and stores the result in `self`.
    fn mul_add_assign(&mut self, a: A, b: B);
}
36 
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f32, f32> for f32 {
    type Output = Self;

    // Forwards to the crate's `Float::mul_add` for `f32`. This impl exists only
    // with the `std` or `libm` feature (see the cfg above), presumably because
    // the `Float` implementation needs one of them — confirm against `Float`.
    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self::Output {
        <f32 as ::Float>::mul_add(self, a, b)
    }
}
46 
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f64, f64> for f64 {
    type Output = Self;

    // Forwards to the crate's `Float::mul_add` for `f64`. This impl exists only
    // with the `std` or `libm` feature (see the cfg above), presumably because
    // the `Float` implementation needs one of them — confirm against `Float`.
    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self::Output {
        <f64 as ::Float>::mul_add(self, a, b)
    }
}
56 
// Implements the trait named by the first argument (expected to be `MulAdd`)
// for every listed type, with `Output = Self` and an unfused `(self * a) + b`
// body. Overflow therefore behaves exactly like the type's own `*` and `+`.
macro_rules! mul_add_impl {
    ($tr:ident for $($ty:ty)*) => {
        $(
            impl $tr for $ty {
                type Output = Self;

                #[inline]
                fn mul_add(self, a: Self, b: Self) -> Self::Output {
                    let product = self * a;
                    product + b
                }
            }
        )*
    };
}
69 
70 mul_add_impl!(MulAdd for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
71 #[cfg(has_i128)]
72 mul_add_impl!(MulAdd for i128 u128);
73 
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f32, f32> for f32 {
    // Computes the fused result through the crate's `Float::mul_add` for
    // `f32`, then writes it back into `self`.
    #[inline]
    fn mul_add_assign(&mut self, a: Self, b: Self) {
        let fused = <f32 as ::Float>::mul_add(*self, a, b);
        *self = fused;
    }
}
81 
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f64, f64> for f64 {
    // Computes the fused result through the crate's `Float::mul_add` for
    // `f64`, then writes it back into `self`.
    #[inline]
    fn mul_add_assign(&mut self, a: Self, b: Self) {
        let fused = <f64 as ::Float>::mul_add(*self, a, b);
        *self = fused;
    }
}
89 
// Implements the trait named by the first argument (expected to be
// `MulAddAssign`) for every listed type by storing the unfused
// `(*self * a) + b` back into `self`. Overflow behaves exactly like the
// type's own `*` and `+`.
macro_rules! mul_add_assign_impl {
    ($tr:ident for $($ty:ty)*) => {
        $(
            impl $tr for $ty {
                #[inline]
                fn mul_add_assign(&mut self, a: Self, b: Self) {
                    let product = *self * a;
                    *self = product + b;
                }
            }
        )*
    };
}
100 
101 mul_add_assign_impl!(MulAddAssign for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
102 #[cfg(has_i128)]
103 mul_add_assign_impl!(MulAddAssign for i128 u128);
104 
#[cfg(test)]
mod tests {
    use super::*;

    // Integer arithmetic has no rounding, so both `MulAdd` and `MulAddAssign`
    // must agree exactly with the plain expression `(m * x) + b`.
    macro_rules! check_integer {
        ($($t:ident)+) => {
            $(
                {
                    let m: $t = 2;
                    let x: $t = 3;
                    let b: $t = 4;

                    assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));

                    let mut acc = m;
                    MulAddAssign::mul_add_assign(&mut acc, x, b);
                    assert_eq!(acc, (m*x + b));
                }
            )+
        };
    }

    // Fused and unfused float results may differ by the rounding of `m * x`,
    // so the comparison with `(m * x) + b` uses a tolerance scaled to the
    // result's magnitude (12.0 * 3.4 + 5.6 == 46.4). `MulAddAssign` must store
    // exactly what `MulAdd` returns for the same operands.
    #[cfg(feature = "std")]
    macro_rules! check_float {
        ($($t:ident)+) => {
            $(
                {
                    use core::$t;

                    let m: $t = 12.0;
                    let x: $t = 3.4;
                    let b: $t = 5.6;

                    let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();
                    assert!(abs_difference <= 46.4 * $t::EPSILON);

                    let mut acc = m;
                    MulAddAssign::mul_add_assign(&mut acc, x, b);
                    assert_eq!(acc, MulAdd::mul_add(m, x, b));
                }
            )+
        };
    }

    #[test]
    fn mul_add_integer() {
        check_integer!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
    }

    // The 128-bit impls are only generated under `has_i128`, so they are
    // tested under the same cfg.
    #[test]
    #[cfg(has_i128)]
    fn mul_add_integer_i128() {
        check_integer!(i128 u128);
    }

    // `.abs()` on floats is used above, which is why this stays gated on `std`.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_float() {
        check_float!(f32 f64);
    }
}
152