#![allow(dead_code)]

/// Sender maximum segment size in bytes (a typical Ethernet-derived MSS).
const MSS: u32 = 1460;

/// Lower bound on the retransmission timeout, in microseconds (RFC 6298: 1 s).
const RTO_MIN_US: u64 = 1_000_000;

/// Upper bound on the retransmission timeout, in microseconds (RFC 6298 allows capping at 60 s or more).
const RTO_MAX_US: u64 = 60_000_000;

/// RTO used before the first RTT measurement, in microseconds.
const RTO_INITIAL_US: u64 = 1_000_000;

/// Clock granularity G used in the RTO computation, in microseconds.
const CLOCK_GRANULARITY_US: u64 = 1_000;

/// Number of duplicate ACKs that triggers fast retransmit.
const DUP_ACK_THRESHOLD: u32 = 3;

/// SRTT is stored left-shifted by 3 bits (gain 1/8, RFC 6298).
const SRTT_SHIFT: u32 = 3;
/// RTTVAR is stored left-shifted by 2 bits (gain 1/4, RFC 6298).
const RTTVAR_SHIFT: u32 = 2;

/// Phase of the congestion-control state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CongestionPhase {
    SlowStart,
    CongestionAvoidance,
    FastRecovery,
}

/// Common interface shared by the Reno and CUBIC controllers.
pub trait CongestionController: Send + Sync {
    /// Handle an ACK of new data. `rtt_us` is the RTT sample associated with
    /// this ACK in microseconds (0 when no valid sample is available).
    fn on_ack(&mut self, bytes_acked: u32, rtt_us: u64);

    /// Handle a duplicate ACK.
    fn on_duplicate_ack(&mut self);

    /// Handle expiry of the retransmission timer.
    fn on_timeout(&mut self);

    /// Current congestion window in bytes.
    fn congestion_window(&self) -> u32;

    /// Current slow-start threshold in bytes.
    fn slow_start_threshold(&self) -> u32;
}
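
// Illustrative sketch, not part of the original API: the trait is object safe,
// so a sender can drive any controller through a trait object. The helper name
// `can_send_example` and the `bytes_in_flight` parameter are assumptions made
// here for demonstration only.
fn can_send_example(cc: &dyn CongestionController, bytes_in_flight: u32) -> bool {
    // Classic window check: transmit while unacknowledged data stays under cwnd.
    bytes_in_flight < cc.congestion_window()
}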

/// State shared by both congestion-control implementations.
#[derive(Debug, Clone)]
pub struct CongestionState {
    /// Congestion window in bytes.
    pub cwnd: u32,
    /// Slow-start threshold in bytes.
    pub ssthresh: u32,
    /// Smoothed RTT, stored left-shifted by `SRTT_SHIFT` (microseconds).
    rtt_estimate_shifted: u64,
    /// RTT variance, stored left-shifted by `RTTVAR_SHIFT` (microseconds).
    rtt_variance_shifted: u64,
    /// Current retransmission timeout in microseconds.
    pub rto: u64,
    /// Consecutive duplicate ACKs seen since the last new ACK.
    pub dup_ack_count: u32,
    /// Current phase of the state machine.
    pub phase: CongestionPhase,
    /// Whether the first RTT sample has been folded in.
    rtt_initialized: bool,
}

impl Default for CongestionState {
    fn default() -> Self {
        Self::new()
    }
}

impl CongestionState {
    /// Fresh state: one-segment window, effectively unbounded ssthresh, initial RTO.
    pub fn new() -> Self {
        Self {
            cwnd: MSS,
            ssthresh: u32::MAX,
            rtt_estimate_shifted: 0,
            rtt_variance_shifted: 0,
            rto: RTO_INITIAL_US,
            dup_ack_count: 0,
            phase: CongestionPhase::SlowStart,
            rtt_initialized: false,
        }
    }

    /// Smoothed RTT in microseconds.
    pub fn srtt_us(&self) -> u64 {
        self.rtt_estimate_shifted >> SRTT_SHIFT
    }

    /// RTT variance in microseconds.
    pub fn rttvar_us(&self) -> u64 {
        self.rtt_variance_shifted >> RTTVAR_SHIFT
    }

    /// Fold an RTT sample into SRTT/RTTVAR and recompute the RTO (RFC 6298).
    ///
    /// Both estimators are stored pre-shifted so the exponentially weighted
    /// averages reduce to integer subtract-and-add steps.
    fn update_rtt(&mut self, rtt_us: u64) {
        if rtt_us == 0 {
            // No valid sample; keep the current estimates.
            return;
        }

        if !self.rtt_initialized {
            // First sample: SRTT = R, RTTVAR = R / 2.
            self.rtt_estimate_shifted = rtt_us << SRTT_SHIFT;
            self.rtt_variance_shifted = (rtt_us / 2) << RTTVAR_SHIFT;
            self.rtt_initialized = true;
        } else {
            // RTTVAR = 3/4 * RTTVAR + 1/4 * |SRTT - R'|
            let srtt_unshifted = self.rtt_estimate_shifted >> SRTT_SHIFT;
            let abs_delta = srtt_unshifted.abs_diff(rtt_us);

            let rttvar_adj = self.rtt_variance_shifted >> RTTVAR_SHIFT;
            self.rtt_variance_shifted = self
                .rtt_variance_shifted
                .saturating_sub(rttvar_adj)
                .saturating_add(abs_delta);

            // SRTT = 7/8 * SRTT + 1/8 * R'
            let srtt_adj = self.rtt_estimate_shifted >> SRTT_SHIFT;
            self.rtt_estimate_shifted = self
                .rtt_estimate_shifted
                .saturating_sub(srtt_adj)
                .saturating_add(rtt_us);
        }

        // RTO = SRTT + max(G, 4 * RTTVAR), clamped to [RTO_MIN_US, RTO_MAX_US].
        let srtt = self.rtt_estimate_shifted >> SRTT_SHIFT;
        let rttvar = self.rtt_variance_shifted >> RTTVAR_SHIFT;
        let k_rttvar = rttvar.saturating_mul(4);
        let rto = srtt.saturating_add(core::cmp::max(CLOCK_GRANULARITY_US, k_rttvar));

        self.rto = rto.clamp(RTO_MIN_US, RTO_MAX_US);
    }
}
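
// Worked example of the update above (illustrative only; values match the unit
// tests below): starting from SRTT = 100 ms and RTTVAR = 50 ms, a new sample
// R' = 120 ms gives
//   RTTVAR' = 3/4 * 50_000 + 1/4 * |100_000 - 120_000| = 42_500 us
//   SRTT'   = 7/8 * 100_000 + 1/8 * 120_000            = 102_500 us
//   RTO     = 102_500 + max(1_000, 4 * 42_500)          = 272_500 us,
// which the final clamp raises to RTO_MIN_US (1 s).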

/// TCP Reno: slow start, congestion avoidance, fast retransmit and fast
/// recovery in the style of RFC 5681.
#[derive(Debug, Clone)]
pub struct RenoController {
    state: CongestionState,
}

impl Default for RenoController {
    fn default() -> Self {
        Self::new()
    }
}

impl RenoController {
    pub fn new() -> Self {
        Self {
            state: CongestionState::new(),
        }
    }

    /// Read-only view of the congestion state.
    pub fn state(&self) -> &CongestionState {
        &self.state
    }

    /// Current retransmission timeout in microseconds.
    pub fn rto_us(&self) -> u64 {
        self.state.rto
    }

    /// Current phase of the state machine.
    pub fn phase(&self) -> CongestionPhase {
        self.state.phase
    }

    /// Duplicate ACKs seen since the last new ACK.
    pub fn dup_ack_count(&self) -> u32 {
        self.state.dup_ack_count
    }

    /// Smoothed RTT in microseconds.
    pub fn srtt_us(&self) -> u64 {
        self.state.srtt_us()
    }

    /// RTT variance in microseconds.
    pub fn rttvar_us(&self) -> u64 {
        self.state.rttvar_us()
    }
}

impl CongestionController for RenoController {
    fn on_ack(&mut self, bytes_acked: u32, rtt_us: u64) {
        self.state.update_rtt(rtt_us);

        // A new cumulative ACK ends any duplicate-ACK run.
        self.state.dup_ack_count = 0;

        match self.state.phase {
            CongestionPhase::SlowStart => {
                // Exponential growth: one MSS per ACK.
                self.state.cwnd = self.state.cwnd.saturating_add(MSS);

                if self.state.cwnd >= self.state.ssthresh {
                    self.state.phase = CongestionPhase::CongestionAvoidance;
                }
            }
            CongestionPhase::CongestionAvoidance => {
                // Additive increase: MSS * MSS / cwnd per ACK, roughly one MSS per RTT.
                let increment = ((MSS as u64) * (MSS as u64) / (self.state.cwnd as u64)) as u32;
                let increment = core::cmp::max(increment, 1);
                self.state.cwnd = self.state.cwnd.saturating_add(increment);
            }
            CongestionPhase::FastRecovery => {
                // New data acknowledged: deflate the window and leave fast recovery.
                self.state.cwnd = self.state.ssthresh;
                self.state.phase = CongestionPhase::CongestionAvoidance;
            }
        }

        // Growth here is per-ACK rather than per-byte, so the argument is unused.
        let _ = bytes_acked;
    }

    fn on_duplicate_ack(&mut self) {
        self.state.dup_ack_count += 1;

        match self.state.phase {
            CongestionPhase::SlowStart | CongestionPhase::CongestionAvoidance => {
                if self.state.dup_ack_count == DUP_ACK_THRESHOLD {
                    // Fast retransmit: halve the window (with a floor of two
                    // segments), then inflate by the three duplicate ACKs
                    // already received.
                    self.state.ssthresh = core::cmp::max(self.state.cwnd / 2, 2 * MSS);
                    self.state.cwnd = self.state.ssthresh + 3 * MSS;
                    self.state.phase = CongestionPhase::FastRecovery;
                }
            }
            CongestionPhase::FastRecovery => {
                // Each further duplicate ACK means another segment left the network.
                self.state.cwnd = self.state.cwnd.saturating_add(MSS);
            }
        }
    }

    fn on_timeout(&mut self) {
        // Retransmission timeout: collapse to one segment, restart slow start,
        // and back off the RTO exponentially.
        self.state.ssthresh = core::cmp::max(self.state.cwnd / 2, 2 * MSS);
        self.state.cwnd = MSS;
        self.state.dup_ack_count = 0;
        self.state.phase = CongestionPhase::SlowStart;
        self.state.rto = core::cmp::min(self.state.rto.saturating_mul(2), RTO_MAX_US);
    }

    fn congestion_window(&self) -> u32 {
        self.state.cwnd
    }

    fn slow_start_threshold(&self) -> u32 {
        self.state.ssthresh
    }
}
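
// Illustrative sketch, not part of the original API (the helper name is an
// assumption): replay one Reno fast-retransmit episode and return the window
// before the loss, while inflated in fast recovery, and after the exit.
fn reno_fast_retransmit_example() -> (u32, u32, u32) {
    let mut cc = RenoController::new();
    // Grow to a 10-segment window in slow start.
    for _ in 0..9 {
        cc.on_ack(MSS, 50_000);
    }
    let before = cc.congestion_window(); // 10 * MSS
    for _ in 0..3 {
        cc.on_duplicate_ack(); // the third one triggers fast retransmit
    }
    let inflated = cc.congestion_window(); // ssthresh + 3 * MSS = 8 * MSS
    cc.on_ack(MSS, 50_000); // a new ACK exits fast recovery
    let after = cc.congestion_window(); // ssthresh = 5 * MSS
    (before, inflated, after)
}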

/// CUBIC constant C = 0.4, expressed as the fraction 410 / 1024.
const CUBIC_C_NUM: u64 = 410;
const CUBIC_C_DEN: u64 = 1024;

/// CUBIC multiplicative-decrease factor beta = 0.7, as 717 / 1024.
const CUBIC_BETA_NUM: u64 = 717;
const CUBIC_BETA_DEN: u64 = 1024;

/// 1 - beta = 0.3, as 307 / 1024.
const CUBIC_ONE_MINUS_BETA_NUM: u64 = 307;
const CUBIC_ONE_MINUS_BETA_DEN: u64 = 1024;

/// Integer cube root: the largest `x` with `x * x * x <= n`, found with a
/// Newton-style iteration on integers.
fn integer_cbrt(n: u64) -> u64 {
    if n == 0 {
        return 0;
    }
    if n < 8 {
        // Every n in 1..8 has an integer cube root of 1.
        return 1;
    }

    // Start from a power of two that is at least as large as the true root.
    let bits = 64 - n.leading_zeros() as u64;
    let mut x = 1u64 << bits.div_ceil(3);

    loop {
        // Newton step for x^3 = n: x_new = (2x + n / x^2) / 3.
        let x2 = x.saturating_mul(x);
        let x_new = if x2 == 0 {
            x >> 1
        } else {
            (2 * x + n / x2) / 3
        };
        if x_new >= x {
            break;
        }
        x = x_new;
    }

    // The iteration can stop one above the floor; correct the overshoot.
    if x.saturating_mul(x).saturating_mul(x) > n {
        x -= 1;
    }
    x
}
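
// Illustrative trace of the iteration above (not part of the original code):
// for n = 1000 the initial guess is x = 1 << div_ceil(10, 3) = 16, and the
// steps are 16 -> (32 + 1000 / 256) / 3 = 11 -> (22 + 1000 / 121) / 3 = 10,
// where the next step stays at 10; since 10^3 = 1000 <= n, no final
// correction is applied and the result is 10.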

/// CUBIC congestion control in the spirit of RFC 8312, computed in whole
/// segments with integer fixed-point arithmetic.
#[derive(Debug, Clone)]
pub struct CubicController {
    state: CongestionState,
    /// Congestion window (bytes) just before the last window reduction.
    w_max: u32,
    /// Epoch-start marker (currently only reset, never read).
    epoch_start: u64,
    /// Time K at which the cubic curve returns to `origin_point`, in microseconds.
    k_us: u64,
    /// `w_max` in whole segments; the plateau of the cubic curve.
    origin_point: u32,
    /// Most recent TCP-friendly window estimate, in bytes.
    tcp_cwnd: u32,
    /// `w_max` from the previous loss event, used for fast convergence.
    prev_w_max: u32,
    /// Time accumulated in the current epoch from per-ACK RTT samples (microseconds).
    elapsed_us: u64,
}

impl Default for CubicController {
    fn default() -> Self {
        Self::new()
    }
}

impl CubicController {
    pub fn new() -> Self {
        Self {
            state: CongestionState::new(),
            w_max: 0,
            epoch_start: 0,
            k_us: 0,
            origin_point: 0,
            tcp_cwnd: 0,
            prev_w_max: 0,
            elapsed_us: 0,
        }
    }

    /// Read-only view of the congestion state.
    pub fn state(&self) -> &CongestionState {
        &self.state
    }

    /// Current retransmission timeout in microseconds.
    pub fn rto_us(&self) -> u64 {
        self.state.rto
    }

    /// Current phase of the state machine.
    pub fn phase(&self) -> CongestionPhase {
        self.state.phase
    }

    /// Duplicate ACKs seen since the last new ACK.
    pub fn dup_ack_count(&self) -> u32 {
        self.state.dup_ack_count
    }

    /// Smoothed RTT in microseconds.
    pub fn srtt_us(&self) -> u64 {
        self.state.srtt_us()
    }

    /// Compute K, the time the cubic curve needs to climb back to `w_max`:
    /// K = cbrt(W_max * (1 - beta) / C) per RFC 8312, evaluated with time
    /// measured in RTTs and then scaled to microseconds.
    fn compute_k(&self, w_max_segs: u64) -> u64 {
        let numerator = w_max_segs
            .saturating_mul(CUBIC_ONE_MINUS_BETA_NUM)
            .saturating_mul(CUBIC_C_DEN);
        let denominator = CUBIC_ONE_MINUS_BETA_DEN.saturating_mul(CUBIC_C_NUM);
        let k_cubed = if denominator == 0 {
            0
        } else {
            numerator / denominator
        };
        let k_segs = integer_cbrt(k_cubed);

        // Fall back to a 100 ms RTT before the first sample is available.
        let srtt = self.state.srtt_us();
        let rtt = if srtt > 0 { srtt } else { 100_000 };
        k_segs.saturating_mul(rtt)
    }

    /// W_cubic(t) = C * (t - K)^3 + W_max, evaluated in whole segments with the
    /// elapsed time measured in RTTs and held in Q10 (x1024) fixed point.
    fn cubic_window(&self, t_us: u64) -> u32 {
        let rtt = {
            let s = self.state.srtt_us();
            if s > 0 {
                s
            } else {
                100_000
            }
        };

        // Elapsed time and K converted to RTT units, scaled by 1024.
        let t_rtt_fp = if rtt > 0 {
            t_us.saturating_mul(1024) / rtt
        } else {
            0
        };
        let k_rtt_fp = if rtt > 0 {
            self.k_us.saturating_mul(1024) / rtt
        } else {
            0
        };

        // |t - K| and whether t is still left of K (concave region).
        let (diff_fp, negative) = if t_rtt_fp >= k_rtt_fp {
            (t_rtt_fp - k_rtt_fp, false)
        } else {
            (k_rtt_fp - t_rtt_fp, true)
        };

        // (t - K)^3 in Q10, multiplied by C; the trailing division drops the
        // remaining Q10 scale so the term is in whole segments.
        let diff2 = diff_fp.saturating_mul(diff_fp) / 1024;
        let diff3 = diff2.saturating_mul(diff_fp) / 1024;
        let c_diff3 = diff3.saturating_mul(CUBIC_C_NUM) / CUBIC_C_DEN / 1024;

        let origin_segs = self.origin_point as u64;
        let w_segs = if negative {
            origin_segs.saturating_sub(c_diff3)
        } else {
            origin_segs.saturating_add(c_diff3)
        };

        let w_bytes = w_segs.saturating_mul(MSS as u64);
        if w_bytes > u32::MAX as u64 {
            u32::MAX
        } else {
            core::cmp::max(w_bytes as u32, MSS)
        }
    }

    /// TCP-friendly window estimate from RFC 8312:
    /// W_est(t) = W_max * beta + [3 * (1 - beta) / (1 + beta)] * (t / RTT),
    /// in whole segments.
    fn tcp_friendly_window(&self, t_us: u64) -> u32 {
        let rtt = {
            let s = self.state.srtt_us();
            if s > 0 {
                s
            } else {
                100_000
            }
        };
        let w_max_segs = self.origin_point as u64;

        // Starting point: the window right after the multiplicative decrease.
        let base = w_max_segs.saturating_mul(CUBIC_BETA_NUM) / CUBIC_BETA_DEN;

        // Additive slope of 3 * (1 - beta) / (1 + beta) segments per RTT.
        let slope_num: u64 = 3 * CUBIC_ONE_MINUS_BETA_NUM;
        let slope_den: u64 = CUBIC_BETA_DEN + CUBIC_BETA_NUM;
        let t_rtts = if rtt > 0 { t_us / rtt } else { 0 };

        let increment = t_rtts.saturating_mul(slope_num) / slope_den;

        let w_segs = base.saturating_add(increment);
        let w_bytes = w_segs.saturating_mul(MSS as u64);
        if w_bytes > u32::MAX as u64 {
            u32::MAX
        } else {
            core::cmp::max(w_bytes as u32, MSS)
        }
    }

    /// Begin a new congestion-avoidance epoch: pin the cubic plateau at the
    /// current `w_max` and recompute K.
    fn start_epoch(&mut self) {
        let w_max_segs = (self.w_max as u64) / (MSS as u64);
        self.origin_point = w_max_segs as u32;
        self.k_us = self.compute_k(w_max_segs);
        self.elapsed_us = 0;
        self.tcp_cwnd = self.state.cwnd;
    }
}
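
// Worked example for the two window curves above (illustrative only): with
// w_max = 100 segments and SRTT = 50 ms, compute_k gives
// K = cbrt(100 * 0.3 / 0.4) RTTs, which the integer math evaluates as
// integer_cbrt(74) = 4 RTTs = 200 ms. One RTT into the epoch cubic_window
// yields about 100 - 0.4 * 3^3 = 90 segments (concave region), and at
// t = 400 ms it yields about 100 + 0.4 * 4^3 = 125 segments (convex region).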

impl CongestionController for CubicController {
    fn on_ack(&mut self, bytes_acked: u32, rtt_us: u64) {
        self.state.update_rtt(rtt_us);
        self.state.dup_ack_count = 0;

        match self.state.phase {
            CongestionPhase::SlowStart => {
                // Same exponential growth as Reno while below ssthresh.
                self.state.cwnd = self.state.cwnd.saturating_add(MSS);
                if self.state.cwnd >= self.state.ssthresh {
                    self.state.phase = CongestionPhase::CongestionAvoidance;
                    if self.w_max == 0 {
                        // No loss seen yet: use the current window as the plateau.
                        self.w_max = self.state.cwnd;
                    }
                    self.start_epoch();
                }
            }
            CongestionPhase::CongestionAvoidance => {
                // Advance epoch time by one RTT sample per ACK.
                let rtt_sample = if rtt_us > 0 {
                    rtt_us
                } else {
                    self.state.srtt_us()
                };
                if rtt_sample > 0 {
                    self.elapsed_us = self.elapsed_us.saturating_add(rtt_sample);
                }

                let w_cubic = self.cubic_window(self.elapsed_us);
                let w_tcp = self.tcp_friendly_window(self.elapsed_us);

                // Grow toward whichever target is larger (TCP-friendly region).
                let target = core::cmp::max(w_cubic, w_tcp);

                if target > self.state.cwnd {
                    let delta = target - self.state.cwnd;
                    let increment = ((MSS as u64) * (MSS as u64) / (self.state.cwnd as u64)) as u32;
                    let increment = core::cmp::min(core::cmp::max(increment, 1), delta);
                    self.state.cwnd = self.state.cwnd.saturating_add(increment);
                }

                self.tcp_cwnd = w_tcp;
            }
            CongestionPhase::FastRecovery => {
                // New data acknowledged: deflate to ssthresh and restart the epoch.
                self.state.cwnd = self.state.ssthresh;
                self.state.phase = CongestionPhase::CongestionAvoidance;
                self.start_epoch();
            }
        }

        let _ = bytes_acked;
    }

    fn on_duplicate_ack(&mut self) {
        self.state.dup_ack_count += 1;

        match self.state.phase {
            CongestionPhase::SlowStart | CongestionPhase::CongestionAvoidance => {
                if self.state.dup_ack_count == DUP_ACK_THRESHOLD {
                    let current_cwnd = self.state.cwnd;

                    // Fast convergence: if the window never regained the previous
                    // peak, remember a slightly smaller plateau of
                    // (1 + beta) / 2 times the current window.
                    if current_cwnd < self.prev_w_max {
                        self.w_max = ((current_cwnd as u64) * (CUBIC_BETA_DEN + CUBIC_BETA_NUM)
                            / (2 * CUBIC_BETA_DEN)) as u32;
                    } else {
                        self.w_max = current_cwnd;
                    }
                    self.prev_w_max = current_cwnd;

                    // Multiplicative decrease by beta = 0.7, then the usual
                    // three-segment inflation on entering fast recovery.
                    let new_cwnd = ((current_cwnd as u64) * CUBIC_BETA_NUM / CUBIC_BETA_DEN) as u32;
                    self.state.ssthresh = core::cmp::max(new_cwnd, 2 * MSS);
                    self.state.cwnd = self.state.ssthresh + 3 * MSS;
                    self.state.phase = CongestionPhase::FastRecovery;

                    self.start_epoch();
                }
            }
            CongestionPhase::FastRecovery => {
                self.state.cwnd = self.state.cwnd.saturating_add(MSS);
            }
        }
    }

    fn on_timeout(&mut self) {
        self.prev_w_max = self.w_max;
        self.w_max = self.state.cwnd;

        // Timeout: reduce ssthresh by beta, collapse the window to one segment,
        // and back off the RTO exponentially.
        let new_ssthresh = ((self.state.cwnd as u64) * CUBIC_BETA_NUM / CUBIC_BETA_DEN) as u32;
        self.state.ssthresh = core::cmp::max(new_ssthresh, 2 * MSS);
        self.state.cwnd = MSS;
        self.state.dup_ack_count = 0;
        self.state.phase = CongestionPhase::SlowStart;
        self.state.rto = core::cmp::min(self.state.rto.saturating_mul(2), RTO_MAX_US);

        self.epoch_start = 0;
        self.elapsed_us = 0;
    }

    fn congestion_window(&self) -> u32 {
        self.state.cwnd
    }

    fn slow_start_threshold(&self) -> u32 {
        self.state.ssthresh
    }
}
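
// Illustrative sketch, not part of the original API (the helper name is an
// assumption): drive a Reno and a CUBIC controller through the same event
// sequence. After the timeout Reno restarts slow start toward half the old
// window while CUBIC restarts toward 0.7 of it, so the two windows diverge.
fn compare_controllers_example(rtt_us: u64) -> (u32, u32) {
    let mut reno = RenoController::new();
    let mut cubic = CubicController::new();
    for _ in 0..20 {
        reno.on_ack(MSS, rtt_us);
        cubic.on_ack(MSS, rtt_us);
    }
    reno.on_timeout();
    cubic.on_timeout();
    for _ in 0..40 {
        reno.on_ack(MSS, rtt_us);
        cubic.on_ack(MSS, rtt_us);
    }
    (reno.congestion_window(), cubic.congestion_window())
}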

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_initial_state() {
        let cc = RenoController::new();
        assert_eq!(cc.congestion_window(), MSS);
        assert_eq!(cc.slow_start_threshold(), u32::MAX);
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
        assert_eq!(cc.dup_ack_count(), 0);
        assert_eq!(cc.rto_us(), RTO_INITIAL_US);
    }

    #[test]
    fn test_slow_start_growth() {
        let mut cc = RenoController::new();
        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.congestion_window(), 2 * MSS);
        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.congestion_window(), 3 * MSS);
        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.congestion_window(), 4 * MSS);
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
    }

    #[test]
    fn test_slow_start_to_congestion_avoidance() {
        let mut cc = RenoController::new();
        cc.state.ssthresh = 3 * MSS;
        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
    }

    #[test]
    fn test_congestion_avoidance_linear_growth() {
        let mut cc = RenoController::new();
        cc.state.cwnd = 4 * MSS;
        cc.state.ssthresh = 4 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        let initial_cwnd = cc.congestion_window();
        cc.on_ack(MSS, 50_000);
        let increment = cc.congestion_window() - initial_cwnd;

        let expected = MSS / 4;
        assert_eq!(increment, expected);
    }

    #[test]
    fn test_congestion_avoidance_minimum_increment() {
        let mut cc = RenoController::new();
        cc.state.cwnd = u32::MAX / 2;
        cc.state.ssthresh = MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        let initial = cc.congestion_window();
        cc.on_ack(MSS, 50_000);
        assert!(cc.congestion_window() > initial);
    }

    #[test]
    fn test_fast_retransmit_on_3_dup_acks() {
        let mut cc = RenoController::new();
        cc.state.cwnd = 10 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        let original_cwnd = cc.congestion_window();

        cc.on_duplicate_ack();
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
        cc.on_duplicate_ack();
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
        cc.on_duplicate_ack();
        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);

        let expected_ssthresh = core::cmp::max(original_cwnd / 2, 2 * MSS);
        assert_eq!(cc.slow_start_threshold(), expected_ssthresh);

        assert_eq!(cc.congestion_window(), expected_ssthresh + 3 * MSS);
    }

    #[test]
    fn test_fast_recovery_inflation() {
        let mut cc = RenoController::new();
        cc.state.cwnd = 10 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        for _ in 0..3 {
            cc.on_duplicate_ack();
        }
        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);
        let cwnd_after_fr = cc.congestion_window();

        cc.on_duplicate_ack();
        assert_eq!(cc.congestion_window(), cwnd_after_fr + MSS);
        cc.on_duplicate_ack();
        assert_eq!(cc.congestion_window(), cwnd_after_fr + 2 * MSS);
    }

    #[test]
    fn test_fast_recovery_exit_on_new_ack() {
        let mut cc = RenoController::new();
        cc.state.cwnd = 10 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        for _ in 0..3 {
            cc.on_duplicate_ack();
        }
        let ssthresh = cc.slow_start_threshold();

        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.congestion_window(), ssthresh);
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
        assert_eq!(cc.dup_ack_count(), 0);
    }

    #[test]
    fn test_timeout_resets_to_slow_start() {
        let mut cc = RenoController::new();
        cc.state.cwnd = 20 * MSS;
        cc.state.ssthresh = 15 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        let cwnd_before = cc.congestion_window();
        cc.on_timeout();

        assert_eq!(cc.congestion_window(), MSS);
        assert_eq!(
            cc.slow_start_threshold(),
            core::cmp::max(cwnd_before / 2, 2 * MSS)
        );
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
    }

    #[test]
    fn test_timeout_doubles_rto() {
        let mut cc = RenoController::new();
        let rto_initial = cc.rto_us();

        cc.on_timeout();
        assert_eq!(cc.rto_us(), rto_initial * 2);

        cc.on_timeout();
        assert_eq!(cc.rto_us(), rto_initial * 4);
    }

    #[test]
    fn test_rto_clamped_to_max() {
        let mut cc = RenoController::new();
        cc.state.rto = RTO_MAX_US;
        cc.on_timeout();
        assert_eq!(cc.rto_us(), RTO_MAX_US);
    }

    #[test]
    fn test_rto_first_rtt_sample() {
        let mut cc = RenoController::new();
        let rtt = 100_000u64;
        cc.on_ack(MSS, rtt);

        assert_eq!(cc.srtt_us(), rtt);
        assert_eq!(cc.rttvar_us(), rtt / 2);

        let expected_rto = rtt + 4 * (rtt / 2);
        let expected_rto = expected_rto.clamp(RTO_MIN_US, RTO_MAX_US);
        assert_eq!(cc.rto_us(), expected_rto);
    }

    #[test]
    fn test_rto_subsequent_rtt_sample() {
        let mut cc = RenoController::new();

        cc.on_ack(MSS, 100_000);
        let srtt_after_first = cc.srtt_us();
        assert_eq!(srtt_after_first, 100_000);

        cc.on_ack(MSS, 120_000);
        let srtt_after_second = cc.srtt_us();

        assert_eq!(srtt_after_second, 102_500);
    }

    #[test]
    fn test_rto_minimum_enforced() {
        let mut cc = RenoController::new();
        cc.on_ack(MSS, 100);
        assert!(cc.rto_us() >= RTO_MIN_US);
    }

    #[test]
    fn test_zero_rtt_ignored() {
        let mut cc = RenoController::new();
        let rto_before = cc.rto_us();
        cc.on_ack(MSS, 0);
        assert_eq!(cc.rto_us(), rto_before);
    }

    #[test]
    fn test_ssthresh_minimum_2mss() {
        let mut cc = RenoController::new();
        cc.state.cwnd = MSS;
        cc.on_timeout();
        assert_eq!(cc.slow_start_threshold(), 2 * MSS);
    }

    #[test]
    fn test_dup_ack_count_reset_on_new_ack() {
        let mut cc = RenoController::new();
        cc.on_duplicate_ack();
        cc.on_duplicate_ack();
        assert_eq!(cc.dup_ack_count(), 2);

        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.dup_ack_count(), 0);
    }

    #[test]
    fn test_full_congestion_cycle() {
        let mut cc = RenoController::new();

        for _ in 0..5 {
            cc.on_ack(MSS, 50_000);
        }
        assert_eq!(cc.congestion_window(), 6 * MSS);
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);

        cc.on_timeout();
        assert_eq!(cc.congestion_window(), MSS);
        assert_eq!(
            cc.slow_start_threshold(),
            core::cmp::max(6 * MSS / 2, 2 * MSS)
        );
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);

        let ssthresh = cc.slow_start_threshold();
        while cc.phase() == CongestionPhase::SlowStart {
            cc.on_ack(MSS, 50_000);
        }
        assert!(cc.congestion_window() >= ssthresh);
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);

        for _ in 0..3 {
            cc.on_duplicate_ack();
        }
        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);

        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
    }

    #[test]
    fn test_integer_cbrt_exact_cubes() {
        assert_eq!(integer_cbrt(0), 0);
        assert_eq!(integer_cbrt(1), 1);
        assert_eq!(integer_cbrt(8), 2);
        assert_eq!(integer_cbrt(27), 3);
        assert_eq!(integer_cbrt(64), 4);
        assert_eq!(integer_cbrt(125), 5);
        assert_eq!(integer_cbrt(1000), 10);
        assert_eq!(integer_cbrt(1_000_000), 100);
        assert_eq!(integer_cbrt(1_000_000_000), 1000);
    }

    #[test]
    fn test_integer_cbrt_non_exact() {
        assert_eq!(integer_cbrt(2), 1);
        assert_eq!(integer_cbrt(7), 1);
        assert_eq!(integer_cbrt(9), 2);
        assert_eq!(integer_cbrt(26), 2);
        assert_eq!(integer_cbrt(63), 3);
        assert_eq!(integer_cbrt(100), 4);
    }

    #[test]
    fn test_integer_cbrt_large_values() {
        let val = 1_000_000_000_000u64;
        let root = integer_cbrt(val);
        assert_eq!(root, 10_000);

        let root = integer_cbrt(u64::MAX);
        assert!(root >= 2_642_245);
        assert!(root
            .checked_mul(root)
            .and_then(|r2| r2.checked_mul(root))
            .is_some());
        let r1 = root + 1;
        assert!(r1
            .checked_mul(r1)
            .and_then(|r2| r2.checked_mul(r1))
            .is_none());
    }

    #[test]
    fn test_cubic_initial_state() {
        let cc = CubicController::new();
        assert_eq!(cc.congestion_window(), MSS);
        assert_eq!(cc.slow_start_threshold(), u32::MAX);
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
        assert_eq!(cc.w_max, 0);
    }

    #[test]
    fn test_cubic_slow_start_identical_to_reno() {
        let mut cubic = CubicController::new();
        let mut reno = RenoController::new();

        for _ in 0..5 {
            cubic.on_ack(MSS, 50_000);
            reno.on_ack(MSS, 50_000);
        }
        assert_eq!(cubic.congestion_window(), reno.congestion_window());
        assert_eq!(cubic.phase(), CongestionPhase::SlowStart);
    }

    #[test]
    fn test_cubic_loss_response_beta_07() {
        let mut cc = CubicController::new();
        cc.state.cwnd = 100 * MSS;
        cc.state.ssthresh = 50 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;
        cc.w_max = 100 * MSS;

        let cwnd_before = cc.congestion_window();

        for _ in 0..3 {
            cc.on_duplicate_ack();
        }

        let expected_ssthresh = ((cwnd_before as u64) * CUBIC_BETA_NUM / CUBIC_BETA_DEN) as u32;
        assert_eq!(cc.slow_start_threshold(), expected_ssthresh);

        assert_eq!(cc.w_max, cwnd_before);
    }

    #[test]
    fn test_cubic_fast_convergence() {
        let mut cc = CubicController::new();
        cc.state.cwnd = 100 * MSS;
        cc.state.ssthresh = 50 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;
        cc.prev_w_max = 120 * MSS;

        let cwnd_before = cc.congestion_window();

        for _ in 0..3 {
            cc.on_duplicate_ack();
        }

        let expected_w_max = ((cwnd_before as u64) * (CUBIC_BETA_DEN + CUBIC_BETA_NUM)
            / (2 * CUBIC_BETA_DEN)) as u32;
        assert_eq!(cc.w_max, expected_w_max);
        assert!(cc.w_max < cwnd_before);
    }

    #[test]
    fn test_cubic_timeout_resets() {
        let mut cc = CubicController::new();
        cc.state.cwnd = 50 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        cc.on_timeout();

        assert_eq!(cc.congestion_window(), MSS);
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
        let expected = ((50 * MSS as u64) * CUBIC_BETA_NUM / CUBIC_BETA_DEN) as u32;
        assert_eq!(cc.slow_start_threshold(), expected);
    }

    #[test]
    fn test_cubic_growth_after_loss() {
        let mut cc = CubicController::new();
        cc.state.cwnd = 100 * MSS;
        cc.state.ssthresh = 50 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        cc.state.update_rtt(50_000);
        for _ in 0..3 {
            cc.on_duplicate_ack();
        }
        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);

        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);

        let cwnd_after_loss = cc.congestion_window();

        for _ in 0..20 {
            cc.on_ack(MSS, 50_000);
        }
        assert!(
            cc.congestion_window() > cwnd_after_loss,
            "cwnd should grow after loss: {} vs {}",
            cc.congestion_window(),
            cwnd_after_loss
        );
    }

    #[test]
    fn test_cubic_tcp_friendly_region() {
        let mut cc = CubicController::new();
        cc.state.update_rtt(50_000);
        cc.state.cwnd = 10 * MSS;
        cc.state.ssthresh = 10 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;
        cc.w_max = 10 * MSS;
        cc.origin_point = 10;
        cc.k_us = cc.compute_k(10);
        cc.elapsed_us = 0;

        let w_tcp = cc.tcp_friendly_window(0);
        assert!(w_tcp >= MSS);

        let w_tcp_later = cc.tcp_friendly_window(500_000);
        assert!(w_tcp_later >= w_tcp);
    }

    #[test]
    fn test_cubic_window_concave_then_convex() {
        let mut cc = CubicController::new();
        cc.state.update_rtt(50_000);
        cc.w_max = 100 * MSS;
        cc.origin_point = 100;
        cc.k_us = cc.compute_k(100);

        if cc.k_us > 200_000 {
            let w_early = cc.cubic_window(cc.k_us / 4);
            assert!(
                w_early < cc.w_max,
                "early window {} should be below w_max {}",
                w_early,
                cc.w_max
            );
        }

        let w_at_k = cc.cubic_window(cc.k_us);
        let tolerance = 5 * MSS;
        let diff = if w_at_k > cc.w_max {
            w_at_k - cc.w_max
        } else {
            cc.w_max - w_at_k
        };
        assert!(
            diff <= tolerance,
            "window at K ({}) should be close to w_max ({}), diff={}",
            w_at_k,
            cc.w_max,
            diff
        );

        let w_late = cc.cubic_window(cc.k_us * 3);
        assert!(
            w_late > cc.w_max,
            "late window {} should exceed w_max {}",
            w_late,
            cc.w_max
        );
    }

    #[test]
    fn test_cubic_full_congestion_cycle() {
        let mut cc = CubicController::new();

        for _ in 0..5 {
            cc.on_ack(MSS, 50_000);
        }
        assert_eq!(cc.congestion_window(), 6 * MSS);
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);

        cc.on_timeout();
        assert_eq!(cc.congestion_window(), MSS);
        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
        assert!(cc.w_max > 0);

        while cc.phase() == CongestionPhase::SlowStart {
            cc.on_ack(MSS, 50_000);
        }
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);

        let cwnd_at_ca = cc.congestion_window();
        for _ in 0..10 {
            cc.on_ack(MSS, 50_000);
        }
        assert!(cc.congestion_window() >= cwnd_at_ca);

        for _ in 0..3 {
            cc.on_duplicate_ack();
        }
        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);

        cc.on_ack(MSS, 50_000);
        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
    }

    #[test]
    fn test_cubic_recovery_inflation() {
        let mut cc = CubicController::new();
        cc.state.cwnd = 20 * MSS;
        cc.state.phase = CongestionPhase::CongestionAvoidance;

        for _ in 0..3 {
            cc.on_duplicate_ack();
        }
        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);
        let cwnd_fr = cc.congestion_window();

        cc.on_duplicate_ack();
        assert_eq!(cc.congestion_window(), cwnd_fr + MSS);
        cc.on_duplicate_ack();
        assert_eq!(cc.congestion_window(), cwnd_fr + 2 * MSS);
    }
}