1 // Copyright (C) 2020, Cloudflare, Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 //     * Redistributions of source code must retain the above copyright notice,
9 //       this list of conditions and the following disclaimer.
10 //
11 //     * Redistributions in binary form must reproduce the above copyright
12 //       notice, this list of conditions and the following disclaimer in the
13 //       documentation and/or other materials provided with the distribution.
14 //
15 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16 // IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17 // THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 
27 //! HyStart++
28 //!
29 //! This implementation is based on the following I-D:
30 //!
31 //! https://tools.ietf.org/html/draft-balasubramanian-tcpm-hystartplusplus-02
32 
33 use std::cmp;
34 use std::time::Duration;
35 use std::time::Instant;
36 
37 use crate::recovery;
38 
// Constants from I-D.

/// Congestion window, in datagrams, below which no delay-increase check is
/// made (the comparison is against `LOW_CWND * recovery::MAX_DATAGRAM_SIZE`).
const LOW_CWND: usize = 16;

/// Lower bound of the RTT-increase threshold (`last_round_min_rtt / 8` is
/// clamped to at least this value).
const MIN_RTT_THRESH: Duration = Duration::from_millis(4);

/// Upper bound of the RTT-increase threshold (`last_round_min_rtt / 8` is
/// clamped to at most this value).
const MAX_RTT_THRESH: Duration = Duration::from_millis(16);

/// Divisor of the Limited Slow Start growth factor
/// `K = cwnd / (LSS_DIVISOR * ssthresh)`.
pub const LSS_DIVISOR: f64 = 0.25;

/// Minimum number of RTT samples in the current round before a delay
/// increase is checked.
pub const N_RTT_SAMPLE: usize = 8;
49 
/// HyStart++ state: per-round minimum-RTT tracking plus the time at which
/// Limited Slow Start (LSS) was entered.
#[derive(Default)]
pub struct Hystart {
    /// Flag supplied at construction and reported back by `enabled()`.
    enabled: bool,
    /// Packet number whose acknowledgement ends the current round; `None`
    /// between rounds.
    window_end: Option<u64>,
    /// Minimum RTT of the previous round, used as the delay baseline.
    last_round_min_rtt: Option<Duration>,
    /// Minimum RTT seen so far in the current round.
    current_round_min_rtt: Option<Duration>,
    /// Number of RTT samples taken in the current round.
    rtt_sample_count: usize,
    /// When LSS was entered; `None` while still in regular slow start.
    lss_start_time: Option<Instant>,
}
64 
65 impl std::fmt::Debug for Hystart {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result66     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
67         write!(f, "window_end={:?} ", self.window_end)?;
68         write!(f, "last_round_min_rtt={:?} ", self.last_round_min_rtt)?;
69         write!(f, "current_round_min_rtt={:?} ", self.current_round_min_rtt)?;
70         write!(f, "rtt_sample_count={:?} ", self.rtt_sample_count)?;
71         write!(f, "lss_start_time={:?} ", self.lss_start_time)?;
72 
73         Ok(())
74     }
75 }
76 
77 impl Hystart {
new(enabled: bool) -> Self78     pub fn new(enabled: bool) -> Self {
79         Self {
80             enabled,
81 
82             ..Default::default()
83         }
84     }
85 
enabled(&self) -> bool86     pub fn enabled(&self) -> bool {
87         self.enabled
88     }
89 
lss_start_time(&self) -> Option<Instant>90     pub fn lss_start_time(&self) -> Option<Instant> {
91         self.lss_start_time
92     }
93 
start_round(&mut self, pkt_num: u64)94     pub fn start_round(&mut self, pkt_num: u64) {
95         if self.window_end.is_none() {
96             *self = Hystart {
97                 enabled: self.enabled,
98 
99                 window_end: Some(pkt_num),
100 
101                 last_round_min_rtt: self.current_round_min_rtt,
102 
103                 current_round_min_rtt: None,
104 
105                 rtt_sample_count: 0,
106 
107                 lss_start_time: None,
108             };
109         }
110     }
111 
112     // Returns a new (ssthresh, cwnd) during slow start.
on_packet_acked( &mut self, packet: &recovery::Acked, rtt: Duration, cwnd: usize, ssthresh: usize, now: Instant, ) -> (usize, usize)113     pub fn on_packet_acked(
114         &mut self, packet: &recovery::Acked, rtt: Duration, cwnd: usize,
115         ssthresh: usize, now: Instant,
116     ) -> (usize, usize) {
117         let mut ssthresh = ssthresh;
118         let mut cwnd = cwnd;
119 
120         if self.lss_start_time().is_none() {
121             // Reno Slow Start.
122             cwnd += packet.size;
123 
124             if let Some(current_round_min_rtt) = self.current_round_min_rtt {
125                 self.current_round_min_rtt =
126                     Some(cmp::min(current_round_min_rtt, rtt));
127             } else {
128                 self.current_round_min_rtt = Some(rtt);
129             }
130 
131             self.rtt_sample_count += 1;
132 
133             if cwnd >= (LOW_CWND * recovery::MAX_DATAGRAM_SIZE) &&
134                 self.rtt_sample_count >= N_RTT_SAMPLE &&
135                 self.current_round_min_rtt.is_some() &&
136                 self.last_round_min_rtt.is_some()
137             {
138                 // clamp(min_rtt_thresh, last_round_min_rtt/8,
139                 // max_rtt_thresh)
140                 let rtt_thresh = cmp::max(
141                     self.last_round_min_rtt.unwrap() / 8,
142                     MIN_RTT_THRESH,
143                 );
144                 let rtt_thresh = cmp::min(rtt_thresh, MAX_RTT_THRESH);
145 
146                 // Check if we can exit to LSS.
147                 if self.current_round_min_rtt.unwrap() >=
148                     (self.last_round_min_rtt.unwrap() + rtt_thresh)
149                 {
150                     ssthresh = cwnd;
151 
152                     self.lss_start_time = Some(now);
153                 }
154             }
155 
156             // Check if we reached the end of the round.
157             if let Some(end_pkt_num) = self.window_end {
158                 if packet.pkt_num >= end_pkt_num {
159                     // Start of a new round.
160                     self.window_end = None;
161                 }
162             }
163         } else {
164             // LSS (Limited Slow Start).
165             let k = cwnd as f64 / (LSS_DIVISOR * ssthresh as f64);
166 
167             cwnd += (packet.size as f64 / k) as usize;
168         }
169 
170         (cwnd, ssthresh)
171     }
172 
173     // Exit HyStart++ when entering congestion avoidance.
congestion_event(&mut self)174     pub fn congestion_event(&mut self) {
175         if self.window_end.is_some() {
176             self.window_end = None;
177 
178             self.lss_start_time = None;
179         }
180     }
181 }
182 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn start_round() {
        let mut hspp = Hystart::default();
        let pkt_num = 100;

        hspp.start_round(pkt_num);

        assert_eq!(hspp.window_end, Some(pkt_num));
        assert_eq!(hspp.current_round_min_rtt, None);
    }

    #[test]
    fn reno_slow_start() {
        let mut hspp = Hystart::default();
        let pkt_num = 100;
        let size = 1000;
        let now = Instant::now();

        hspp.start_round(pkt_num);

        assert_eq!(hspp.window_end, Some(pkt_num));

        let p = recovery::Acked {
            pkt_num,
            time_sent: now + Duration::from_millis(10),
            size,
        };

        let init_cwnd = 30000;
        let init_ssthresh = 1000000;

        let (cwnd, ssthresh) = hspp.on_packet_acked(
            &p,
            Duration::from_millis(10),
            init_cwnd,
            init_ssthresh,
            now,
        );

        // Expecting Reno slow start.
        assert!(hspp.lss_start_time().is_none());
        assert_eq!((cwnd, ssthresh), (init_cwnd + size, init_ssthresh));
    }

    #[test]
    fn limited_slow_start() {
        let mut hspp = Hystart::default();
        let size = 1000;
        let now = Instant::now();

        // 1st round rtt = 50ms
        let rtt_1st = 50;

        // end of 1st round
        let pkt_1st = N_RTT_SAMPLE as u64;

        hspp.start_round(pkt_1st);

        assert_eq!(hspp.window_end, Some(pkt_1st));

        let (mut cwnd, mut ssthresh) = (30000, 1000000);
        let mut pkt_num = 0;

        // 1st round.
        for _ in 0..N_RTT_SAMPLE + 1 {
            let p = recovery::Acked {
                pkt_num,
                time_sent: now + Duration::from_millis(pkt_num),
                size,
            };

            // We use a fixed rtt for 1st round.
            let rtt = Duration::from_millis(rtt_1st);

            let (new_cwnd, new_ssthresh) =
                hspp.on_packet_acked(&p, rtt, cwnd, ssthresh, now);

            cwnd = new_cwnd;
            ssthresh = new_ssthresh;

            pkt_num += 1;
        }

        // 2nd round. rtt = 100ms to trigger LSS.
        let rtt_2nd = 100;

        hspp.start_round(pkt_1st * 2 + 1);

        for _ in 0..N_RTT_SAMPLE + 1 {
            let p = recovery::Acked {
                pkt_num,
                time_sent: now + Duration::from_millis(pkt_num),
                size,
            };

            // Keep increasing rtt to simulate buffer queueing delay.
            // This is to exit from slow start to LSS.
            let rtt = Duration::from_millis(rtt_2nd + pkt_num * 4);

            let (new_cwnd, new_ssthresh) =
                hspp.on_packet_acked(&p, rtt, cwnd, ssthresh, now);

            cwnd = new_cwnd;
            ssthresh = new_ssthresh;

            pkt_num += 1;
        }

        // At this point, cwnd exits to LSS mode.
        assert!(hspp.lss_start_time().is_some());

        // cwnd (and thus ssthresh) when LSS was entered: the initial 30000
        // plus `size` for each ack handled in Reno slow start (9 in the 1st
        // round, 8 in the 2nd before the N_RTT_SAMPLE-th sample triggered
        // LSS). The last ack of the 2nd round is then handled by LSS.
        let cur_ssthresh = 47000;
        let k = cur_ssthresh as f64 / (LSS_DIVISOR * cur_ssthresh as f64);
        let lss_cwnd = cur_ssthresh as f64 + size as f64 / k;

        assert_eq!((cwnd, ssthresh), (lss_cwnd as usize, cur_ssthresh));
    }

    #[test]
    fn congestion_event() {
        let mut hspp = Hystart::default();
        let pkt_num = 100;

        hspp.start_round(pkt_num);

        assert_eq!(hspp.window_end, Some(pkt_num));

        // When moving into CA mode, window_end should be cleared.
        hspp.congestion_event();

        assert_eq!(hspp.window_end, None);
        assert!(hspp.lss_start_time().is_none());
    }
}
321