⚠️ VeridianOS Kernel Documentation - This is low-level kernel code. All functions are unsafe unless explicitly marked otherwise. no_std

veridian_kernel/net/
congestion.rs

1//! TCP Congestion Control: Reno and Cubic
2//!
3//! Implements RFC 5681 Reno (slow start, congestion avoidance, fast retransmit,
4//! fast recovery) and RFC 8312 Cubic congestion control. Uses Jacobson's
5//! algorithm (RFC 6298) for RTO estimation. All arithmetic is
6//! integer/fixed-point (no floating point) for `no_std` compatibility.
7
8#![allow(dead_code)]
9
/// Maximum Segment Size in bytes (standard Ethernet MTU minus IP/TCP headers)
const MSS: u32 = 1460;

/// Minimum RTO in microseconds (1 second per RFC 6298)
const RTO_MIN_US: u64 = 1_000_000;

/// Maximum RTO in microseconds (60 seconds per RFC 6298)
const RTO_MAX_US: u64 = 60_000_000;

/// Initial RTO before any RTT measurement (1 second per RFC 6298)
const RTO_INITIAL_US: u64 = 1_000_000;

/// Clock granularity G in microseconds (1ms); lower bound of the variance
/// term in RTO = SRTT + max(G, 4 * RTTVAR)
const CLOCK_GRANULARITY_US: u64 = 1_000;

/// Duplicate ACK threshold for fast retransmit (RFC 5681 Section 3.2)
const DUP_ACK_THRESHOLD: u32 = 3;

/// Fixed-point shift for Jacobson's algorithm (alpha = 1/8, beta = 1/4)
/// SRTT and RTTVAR are stored shifted left by their SHIFT bits for precision.
const SRTT_SHIFT: u32 = 3; // alpha = 1/8 = 1/(2^3)
const RTTVAR_SHIFT: u32 = 2; // beta = 1/4 = 1/(2^2)
32
/// Congestion control phase (RFC 5681 state machine; shared by Reno and Cubic)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CongestionPhase {
    /// Exponential growth: cwnd doubles per RTT (one MSS added per ACK)
    SlowStart,
    /// Linear growth: cwnd increases by ~MSS per RTT
    CongestionAvoidance,
    /// After fast retransmit: inflate cwnd with dup ACKs, deflate on new ACK
    FastRecovery,
}
43
/// Congestion controller trait
///
/// Provides a pluggable interface for congestion control algorithms.
/// The default implementation is TCP Reno (`RenoController`); a Cubic
/// implementation (`CubicController`) is also provided. `Send + Sync`
/// allows a controller to live in per-connection state shared across CPUs.
pub trait CongestionController: Send + Sync {
    /// Called when new data is acknowledged.
    ///
    /// `bytes_acked`: number of newly acknowledged bytes
    /// `rtt_us`: measured round-trip time in microseconds (0 means "no
    /// RTT sample available for this ACK")
    fn on_ack(&mut self, bytes_acked: u32, rtt_us: u64);

    /// Called when a duplicate ACK is received.
    fn on_duplicate_ack(&mut self);

    /// Called when a retransmission timeout fires.
    fn on_timeout(&mut self);

    /// Returns the current congestion window in bytes.
    fn congestion_window(&self) -> u32;

    /// Returns the current slow-start threshold in bytes.
    fn slow_start_threshold(&self) -> u32;
}
67
/// TCP Reno congestion control state
///
/// Shared by both `RenoController` and `CubicController`: holds the
/// congestion window, slow-start threshold, Jacobson RTT estimators
/// (fixed-point form), the RTO, and the duplicate-ACK counter.
#[derive(Debug, Clone)]
pub struct CongestionState {
    /// Congestion window (bytes)
    pub cwnd: u32,
    /// Slow-start threshold (bytes)
    pub ssthresh: u32,
    /// Smoothed RTT estimate (microseconds, fixed-point shifted left by
    /// SRTT_SHIFT); 0 until the first RTT sample arrives
    rtt_estimate_shifted: u64,
    /// RTT variance (microseconds, fixed-point shifted left by RTTVAR_SHIFT)
    rtt_variance_shifted: u64,
    /// Retransmission timeout (microseconds), kept within
    /// [RTO_MIN_US, RTO_MAX_US] by `update_rtt`
    pub rto: u64,
    /// Duplicate ACK counter (reset on every new cumulative ACK)
    pub dup_ack_count: u32,
    /// Current congestion phase
    pub phase: CongestionPhase,
    /// Whether we have taken a first RTT sample (selects RFC 6298
    /// Section 2.2 vs 2.3 behavior in `update_rtt`)
    rtt_initialized: bool,
}
89
90impl Default for CongestionState {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96impl CongestionState {
97    /// Create a new congestion control state with default initial values.
98    ///
99    /// - cwnd starts at 1 MSS (1460 bytes)
100    /// - ssthresh starts at u32::MAX (effectively infinite)
101    /// - RTO starts at 1 second
102    pub fn new() -> Self {
103        Self {
104            cwnd: MSS,
105            ssthresh: u32::MAX,
106            rtt_estimate_shifted: 0,
107            rtt_variance_shifted: 0,
108            rto: RTO_INITIAL_US,
109            dup_ack_count: 0,
110            phase: CongestionPhase::SlowStart,
111            rtt_initialized: false,
112        }
113    }
114
115    /// Return the smoothed RTT estimate in microseconds (unshifted).
116    pub fn srtt_us(&self) -> u64 {
117        self.rtt_estimate_shifted >> SRTT_SHIFT
118    }
119
120    /// Return the RTT variance in microseconds (unshifted).
121    pub fn rttvar_us(&self) -> u64 {
122        self.rtt_variance_shifted >> RTTVAR_SHIFT
123    }
124
125    /// Update RTT estimates using Jacobson's algorithm (RFC 6298).
126    ///
127    /// All arithmetic uses integer/fixed-point operations (no floats).
128    ///
129    /// On the first sample:
130    ///   SRTT = R
131    ///   RTTVAR = R / 2
132    ///
133    /// On subsequent samples:
134    ///   RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R|
135    ///          = RTTVAR - RTTVAR/4 + |SRTT - R|/4
136    ///   SRTT   = (1 - alpha) * SRTT + alpha * R
137    ///          = SRTT - SRTT/8 + R/8
138    ///
139    /// RTO = SRTT + max(G, 4 * RTTVAR)
140    fn update_rtt(&mut self, rtt_us: u64) {
141        if rtt_us == 0 {
142            return;
143        }
144
145        if !self.rtt_initialized {
146            // First RTT measurement (RFC 6298 Section 2.2)
147            self.rtt_estimate_shifted = rtt_us << SRTT_SHIFT;
148            self.rtt_variance_shifted = (rtt_us / 2) << RTTVAR_SHIFT;
149            self.rtt_initialized = true;
150        } else {
151            // Subsequent measurements (RFC 6298 Section 2.3)
152            // Work with shifted values for precision
153
154            // |SRTT - R| in unshifted microseconds
155            let srtt_unshifted = self.rtt_estimate_shifted >> SRTT_SHIFT;
156            let abs_delta = srtt_unshifted.abs_diff(rtt_us);
157
158            // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R|
159            // In shifted form: RTTVAR_s = RTTVAR_s - RTTVAR_s/4 + (|delta| <<
160            // RTTVAR_SHIFT)/4 = RTTVAR_s - RTTVAR_s/4 + |delta|
161            // (since |delta| << RTTVAR_SHIFT >> RTTVAR_SHIFT = |delta|)
162            let rttvar_adj = self.rtt_variance_shifted >> RTTVAR_SHIFT;
163            self.rtt_variance_shifted = self
164                .rtt_variance_shifted
165                .saturating_sub(rttvar_adj)
166                .saturating_add(abs_delta);
167
168            // SRTT = (1 - alpha) * SRTT + alpha * R
169            // In shifted form: SRTT_s = SRTT_s - SRTT_s/8 + (R << SRTT_SHIFT)/8
170            // = SRTT_s - SRTT_s/8 + R
171            let srtt_adj = self.rtt_estimate_shifted >> SRTT_SHIFT;
172            self.rtt_estimate_shifted = self
173                .rtt_estimate_shifted
174                .saturating_sub(srtt_adj)
175                .saturating_add(rtt_us);
176        }
177
178        // RTO = SRTT + max(G, 4 * RTTVAR)
179        let srtt = self.rtt_estimate_shifted >> SRTT_SHIFT;
180        let rttvar = self.rtt_variance_shifted >> RTTVAR_SHIFT;
181        let k_rttvar = rttvar.saturating_mul(4);
182        let rto = srtt.saturating_add(core::cmp::max(CLOCK_GRANULARITY_US, k_rttvar));
183
184        // Clamp to [RTO_MIN, RTO_MAX]
185        self.rto = rto.clamp(RTO_MIN_US, RTO_MAX_US);
186    }
187}
188
/// TCP Reno congestion controller (RFC 5681)
///
/// Thin wrapper around `CongestionState`; the algorithm itself lives in
/// this type's `CongestionController` impl.
#[derive(Debug, Clone)]
pub struct RenoController {
    /// All mutable congestion-control state (cwnd, ssthresh, RTT, RTO, phase)
    state: CongestionState,
}
194
195impl Default for RenoController {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201impl RenoController {
202    /// Create a new Reno congestion controller.
203    pub fn new() -> Self {
204        Self {
205            state: CongestionState::new(),
206        }
207    }
208
209    /// Access the underlying congestion state.
210    pub fn state(&self) -> &CongestionState {
211        &self.state
212    }
213
214    /// Return the current RTO in microseconds.
215    pub fn rto_us(&self) -> u64 {
216        self.state.rto
217    }
218
219    /// Return the current congestion phase.
220    pub fn phase(&self) -> CongestionPhase {
221        self.state.phase
222    }
223
224    /// Return the duplicate ACK count.
225    pub fn dup_ack_count(&self) -> u32 {
226        self.state.dup_ack_count
227    }
228
229    /// Return the smoothed RTT in microseconds.
230    pub fn srtt_us(&self) -> u64 {
231        self.state.srtt_us()
232    }
233
234    /// Return the RTT variance in microseconds.
235    pub fn rttvar_us(&self) -> u64 {
236        self.state.rttvar_us()
237    }
238}
239
240impl CongestionController for RenoController {
241    /// Handle acknowledgment of new data.
242    ///
243    /// In SlowStart: cwnd += MSS per ACK (doubles per RTT)
244    /// In CongestionAvoidance: cwnd += MSS * MSS / cwnd per ACK (~MSS per RTT)
245    /// In FastRecovery: transition to CongestionAvoidance, cwnd = ssthresh
246    /// (deflate)
247    fn on_ack(&mut self, bytes_acked: u32, rtt_us: u64) {
248        // Update RTT estimate
249        self.state.update_rtt(rtt_us);
250
251        // Reset duplicate ACK counter on new ACK
252        self.state.dup_ack_count = 0;
253
254        match self.state.phase {
255            CongestionPhase::SlowStart => {
256                // Exponential growth: add MSS for each ACK
257                // (This effectively doubles cwnd per RTT when all segments are acked)
258                self.state.cwnd = self.state.cwnd.saturating_add(MSS);
259
260                // Transition to congestion avoidance when cwnd >= ssthresh
261                if self.state.cwnd >= self.state.ssthresh {
262                    self.state.phase = CongestionPhase::CongestionAvoidance;
263                }
264            }
265            CongestionPhase::CongestionAvoidance => {
266                // Linear growth: increment by MSS * MSS / cwnd per ACK
267                // This gives approximately MSS bytes growth per RTT.
268                // Use u64 intermediate to avoid overflow in MSS * MSS.
269                let increment = ((MSS as u64) * (MSS as u64) / (self.state.cwnd as u64)) as u32;
270                // Ensure at least 1 byte increase to guarantee progress
271                let increment = core::cmp::max(increment, 1);
272                self.state.cwnd = self.state.cwnd.saturating_add(increment);
273            }
274            CongestionPhase::FastRecovery => {
275                // New ACK received during fast recovery: deflate cwnd
276                self.state.cwnd = self.state.ssthresh;
277                self.state.phase = CongestionPhase::CongestionAvoidance;
278            }
279        }
280
281        let _ = bytes_acked; // Available for future use (e.g., ABC)
282    }
283
284    /// Handle a duplicate ACK.
285    ///
286    /// After 3 duplicate ACKs: enter fast retransmit / fast recovery.
287    /// During fast recovery: inflate cwnd by MSS per additional dup ACK.
288    fn on_duplicate_ack(&mut self) {
289        self.state.dup_ack_count += 1;
290
291        match self.state.phase {
292            CongestionPhase::SlowStart | CongestionPhase::CongestionAvoidance => {
293                if self.state.dup_ack_count == DUP_ACK_THRESHOLD {
294                    // Enter fast retransmit / fast recovery
295                    // ssthresh = max(cwnd / 2, 2 * MSS)
296                    self.state.ssthresh = core::cmp::max(self.state.cwnd / 2, 2 * MSS);
297                    // cwnd = ssthresh + 3 * MSS (account for the 3 dup ACKs)
298                    self.state.cwnd = self.state.ssthresh + 3 * MSS;
299                    self.state.phase = CongestionPhase::FastRecovery;
300                }
301            }
302            CongestionPhase::FastRecovery => {
303                // Each additional dup ACK: inflate cwnd by MSS
304                self.state.cwnd = self.state.cwnd.saturating_add(MSS);
305            }
306        }
307    }
308
309    /// Handle a retransmission timeout.
310    ///
311    /// This is the most severe congestion signal:
312    /// - ssthresh = max(cwnd / 2, 2 * MSS)
313    /// - cwnd = 1 MSS
314    /// - Back to slow start
315    /// - Double the RTO (exponential backoff)
316    fn on_timeout(&mut self) {
317        // ssthresh = max(cwnd / 2, 2 * MSS)
318        self.state.ssthresh = core::cmp::max(self.state.cwnd / 2, 2 * MSS);
319        // Reset cwnd to 1 MSS
320        self.state.cwnd = MSS;
321        // Reset dup ACK counter
322        self.state.dup_ack_count = 0;
323        // Return to slow start
324        self.state.phase = CongestionPhase::SlowStart;
325        // Exponential backoff: double the RTO, clamped to max
326        self.state.rto = core::cmp::min(self.state.rto.saturating_mul(2), RTO_MAX_US);
327    }
328
329    fn congestion_window(&self) -> u32 {
330        self.state.cwnd
331    }
332
333    fn slow_start_threshold(&self) -> u32 {
334        self.state.ssthresh
335    }
336}
337
// ---------------------------------------------------------------------------
// TCP Cubic Congestion Control (RFC 8312)
// ---------------------------------------------------------------------------

/// Cubic parameter C = 0.4, represented as 410/1024 in fixed-point.
const CUBIC_C_NUM: u64 = 410;
const CUBIC_C_DEN: u64 = 1024;

/// Cubic multiplicative-decrease factor beta = 0.7, represented as 717/1024.
const CUBIC_BETA_NUM: u64 = 717;
const CUBIC_BETA_DEN: u64 = 1024;

/// One minus beta = 0.3, represented as 307/1024.
const CUBIC_ONE_MINUS_BETA_NUM: u64 = 307;
const CUBIC_ONE_MINUS_BETA_DEN: u64 = 1024;
353
/// Integer cube root via Newton's method.
///
/// Returns the integer cube root of `n` (i.e., floor(n^(1/3))).
/// Uses pure integer arithmetic -- no floating point.
fn integer_cbrt(n: u64) -> u64 {
    if n == 0 {
        return 0;
    }
    if n < 8 {
        // 1 <= n <= 7 all have integer cube root 1.
        return 1;
    }

    // Initial estimate: a power-of-two upper bound on the root.
    // 2^ceil(bit_length / 3) >= n^(1/3).
    let bits = 64 - n.leading_zeros() as u64;
    let mut x = 1u64 << bits.div_ceil(3);

    // Newton's iteration: x_{k+1} = (2*x_k + n / x_k^2) / 3.
    // Starting from an upper bound the sequence decreases monotonically;
    // stop as soon as it no longer shrinks.
    loop {
        let x2 = x.saturating_mul(x);
        let x_new = if x2 == 0 {
            // Overflow guard (unreachable with the bound above; kept defensively).
            x >> 1
        } else {
            (2 * x + n / x2) / 3
        };
        if x_new >= x {
            break;
        }
        x = x_new;
    }

    // Newton can overshoot; step down until x^3 <= n.
    //
    // checked_mul treats an overflowing cube as "greater than n". The previous
    // saturating_mul check failed for n == u64::MAX: the saturated cube
    // compared *equal* to n, so an overshot root was wrongly accepted.
    // (x >= 2 here and 1^3 <= n, so the loop always terminates above 0.)
    while x
        .checked_mul(x)
        .and_then(|sq| sq.checked_mul(x))
        .map_or(true, |cube| cube > n)
    {
        x -= 1;
    }
    x
}
392
/// TCP Cubic congestion controller (RFC 8312).
///
/// Cubic uses a cubic function of elapsed time since the last congestion event
/// to set the congestion window, providing better bandwidth utilization on
/// high-BDP networks than Reno while remaining TCP-friendly on low-BDP paths.
#[derive(Debug, Clone)]
pub struct CubicController {
    /// Underlying congestion state (cwnd, ssthresh, RTT, RTO, phase).
    state: CongestionState,
    /// cwnd at last loss event (in bytes).
    w_max: u32,
    /// Timestamp of last congestion event (microseconds since boot).
    /// NOTE(review): in this file it is only ever written as 0 — elapsed
    /// time is actually tracked via `elapsed_us`. Confirm before relying
    /// on this field carrying a real timestamp.
    epoch_start: u64,
    /// K value in microseconds (time for the cubic curve to reach w_max).
    k_us: u64,
    /// Origin point (W_max) for the current cubic epoch, in segments.
    origin_point: u32,
    /// TCP-friendly window estimate (bytes); refreshed on each CA ACK.
    tcp_cwnd: u32,
    /// Previous w_max, consulted for fast convergence.
    prev_w_max: u32,
    /// Elapsed time accumulator in microseconds (driven by RTT samples).
    elapsed_us: u64,
}
417
418impl Default for CubicController {
419    fn default() -> Self {
420        Self::new()
421    }
422}
423
impl CubicController {
    /// Create a new Cubic congestion controller.
    ///
    /// The shared `CongestionState` starts in slow start (cwnd = 1 MSS,
    /// ssthresh = u32::MAX); all Cubic-specific fields start at zero and
    /// are seeded when the first cubic epoch begins.
    pub fn new() -> Self {
        Self {
            state: CongestionState::new(),
            w_max: 0,
            epoch_start: 0,
            k_us: 0,
            origin_point: 0,
            tcp_cwnd: 0,
            prev_w_max: 0,
            elapsed_us: 0,
        }
    }

    /// Access the underlying congestion state.
    pub fn state(&self) -> &CongestionState {
        &self.state
    }

    /// Return the current RTO in microseconds.
    pub fn rto_us(&self) -> u64 {
        self.state.rto
    }

    /// Return the current congestion phase.
    pub fn phase(&self) -> CongestionPhase {
        self.state.phase
    }

    /// Return the duplicate ACK count.
    pub fn dup_ack_count(&self) -> u32 {
        self.state.dup_ack_count
    }

    /// Return the smoothed RTT in microseconds.
    pub fn srtt_us(&self) -> u64 {
        self.state.srtt_us()
    }

    /// Compute K (time for the cubic curve to reach w_max) in microseconds.
    ///
    /// K = cbrt(W_max * beta / C) in RTT units, then converted to
    /// microseconds using the current SRTT (default 100ms when no RTT
    /// sample has been taken yet).
    ///
    /// All math is integer-only with u64 intermediates to prevent overflow:
    ///   K^3 = w_max_segs * beta / C
    ///       = w_max_segs * CUBIC_BETA_NUM * CUBIC_C_DEN
    ///         / (CUBIC_BETA_DEN * CUBIC_C_NUM)
    fn compute_k(&self, w_max_segs: u64) -> u64 {
        let numerator = w_max_segs
            .saturating_mul(CUBIC_BETA_NUM)
            .saturating_mul(CUBIC_C_DEN);
        let denominator = CUBIC_BETA_DEN.saturating_mul(CUBIC_C_NUM);
        // denominator is a product of nonzero constants; the zero check is
        // purely defensive.
        let k_cubed = if denominator == 0 {
            0
        } else {
            numerator / denominator
        };
        let k_segs = integer_cbrt(k_cubed);

        // Convert from RTT units to microseconds.
        let srtt = self.state.srtt_us();
        let rtt = if srtt > 0 { srtt } else { 100_000 }; // default 100ms if unknown
        k_segs.saturating_mul(rtt)
    }

    /// Compute the Cubic window W(t) in bytes.
    ///
    /// W(t) = C * (t - K)^3 + W_max   (RFC 8312 Eq. 1)
    ///
    /// `t_us` is elapsed time in microseconds since the congestion event.
    /// Time is handled in RTT units with a 10-bit (1024) fixed-point
    /// fraction; the result is floored at 1 MSS and capped at u32::MAX.
    fn cubic_window(&self, t_us: u64) -> u32 {
        // SRTT with a 100ms fallback before the first sample.
        let rtt = {
            let s = self.state.srtt_us();
            if s > 0 {
                s
            } else {
                100_000
            }
        };

        // Convert t and K from microseconds to RTT units
        // (fixed-point, 10-bit fraction): x_fp = x_us * 1024 / rtt.
        // rtt is always > 0 here; the guards are defensive.
        let t_rtt_fp = if rtt > 0 {
            t_us.saturating_mul(1024) / rtt
        } else {
            0
        };
        let k_rtt_fp = if rtt > 0 {
            self.k_us.saturating_mul(1024) / rtt
        } else {
            0
        };

        // (t - K) in fixed-point RTT units; track the sign separately since
        // all arithmetic is unsigned.
        let (diff_fp, negative) = if t_rtt_fp >= k_rtt_fp {
            (t_rtt_fp - k_rtt_fp, false)
        } else {
            (k_rtt_fp - t_rtt_fp, true)
        };

        // Cube (t - K), removing one 1024 scale factor per multiplication so
        // the final diff3 is in plain (unscaled) RTT units. Step-by-step
        // division keeps intermediates inside u64.
        let diff2 = diff_fp.saturating_mul(diff_fp) / 1024; // scale back one factor
        let diff3 = diff2.saturating_mul(diff_fp) / 1024; // scale back another factor

        // W_cubic = C * diff3 + origin_point (all in segments);
        // the cubic term is subtracted while t < K (left of the plateau).
        let c_diff3 = diff3.saturating_mul(CUBIC_C_NUM) / CUBIC_C_DEN;

        let origin_segs = self.origin_point as u64;
        let w_segs = if negative {
            origin_segs.saturating_sub(c_diff3)
        } else {
            origin_segs.saturating_add(c_diff3)
        };

        // Convert segments to bytes; clamp into [MSS, u32::MAX].
        let w_bytes = w_segs.saturating_mul(MSS as u64);
        if w_bytes > u32::MAX as u64 {
            u32::MAX
        } else {
            core::cmp::max(w_bytes as u32, MSS)
        }
    }

    /// Compute the TCP-friendly (Reno-equivalent) window in bytes.
    ///
    /// W_tcp = W_max * (1 - beta) + 3 * beta / (2 - beta) * t / RTT
    /// (RFC 8312 Eq. 4). Result is floored at 1 MSS and capped at u32::MAX.
    fn tcp_friendly_window(&self, t_us: u64) -> u32 {
        // SRTT with a 100ms fallback before the first sample.
        let rtt = {
            let s = self.state.srtt_us();
            if s > 0 {
                s
            } else {
                100_000
            }
        };
        let w_max_segs = self.origin_point as u64;

        // Base: W_max * (1 - beta) in segments.
        let base = w_max_segs.saturating_mul(CUBIC_ONE_MINUS_BETA_NUM) / CUBIC_ONE_MINUS_BETA_DEN;

        // Slope: 3 * beta / (2 - beta) segments per RTT
        // = 3 * 717 / (2 * 1024 - 717) = 2151 / 1331.
        let slope_num: u64 = 3 * CUBIC_BETA_NUM; // 2151
        let slope_den: u64 = 2 * CUBIC_BETA_DEN - CUBIC_BETA_NUM; // 1331

        // Elapsed time in whole RTTs (rtt > 0 always; guard is defensive).
        let t_rtts = if rtt > 0 { t_us / rtt } else { 0 };

        let increment = t_rtts.saturating_mul(slope_num) / slope_den;

        let w_segs = base.saturating_add(increment);
        let w_bytes = w_segs.saturating_mul(MSS as u64);
        if w_bytes > u32::MAX as u64 {
            u32::MAX
        } else {
            core::cmp::max(w_bytes as u32, MSS)
        }
    }

    /// Start a new cubic epoch after a loss event: fix the origin point at
    /// w_max (in segments), recompute K, restart the elapsed-time clock, and
    /// seed the TCP-friendly estimate from the current cwnd.
    fn start_epoch(&mut self) {
        let w_max_segs = (self.w_max as u64) / (MSS as u64);
        self.origin_point = w_max_segs as u32;
        self.k_us = self.compute_k(w_max_segs);
        self.elapsed_us = 0;
        self.tcp_cwnd = self.state.cwnd;
    }
}
599
600impl CongestionController for CubicController {
601    fn on_ack(&mut self, bytes_acked: u32, rtt_us: u64) {
602        self.state.update_rtt(rtt_us);
603        self.state.dup_ack_count = 0;
604
605        match self.state.phase {
606            CongestionPhase::SlowStart => {
607                self.state.cwnd = self.state.cwnd.saturating_add(MSS);
608                if self.state.cwnd >= self.state.ssthresh {
609                    self.state.phase = CongestionPhase::CongestionAvoidance;
610                    // Begin cubic epoch when entering CA
611                    if self.w_max == 0 {
612                        self.w_max = self.state.cwnd;
613                    }
614                    self.start_epoch();
615                }
616            }
617            CongestionPhase::CongestionAvoidance => {
618                // Advance elapsed time by one RTT sample
619                let rtt_sample = if rtt_us > 0 {
620                    rtt_us
621                } else {
622                    self.state.srtt_us()
623                };
624                if rtt_sample > 0 {
625                    self.elapsed_us = self.elapsed_us.saturating_add(rtt_sample);
626                }
627
628                // Compute cubic window and TCP-friendly window
629                let w_cubic = self.cubic_window(self.elapsed_us);
630                let w_tcp = self.tcp_friendly_window(self.elapsed_us);
631
632                // Use the larger of cubic and TCP-friendly (ensures fairness)
633                let target = core::cmp::max(w_cubic, w_tcp);
634
635                // Increase toward target: add MSS * MSS / cwnd per ACK (bounded)
636                if target > self.state.cwnd {
637                    let delta = target - self.state.cwnd;
638                    let increment = ((MSS as u64) * (MSS as u64) / (self.state.cwnd as u64)) as u32;
639                    let increment = core::cmp::min(core::cmp::max(increment, 1), delta);
640                    self.state.cwnd = self.state.cwnd.saturating_add(increment);
641                } else {
642                    // Cubic says reduce, but don't decrease below current -
643                    // just hold (Cubic only decreases on
644                    // loss, not proactively)
645                }
646
647                self.tcp_cwnd = w_tcp;
648            }
649            CongestionPhase::FastRecovery => {
650                self.state.cwnd = self.state.ssthresh;
651                self.state.phase = CongestionPhase::CongestionAvoidance;
652                self.start_epoch();
653            }
654        }
655
656        let _ = bytes_acked;
657    }
658
659    fn on_duplicate_ack(&mut self) {
660        self.state.dup_ack_count += 1;
661
662        match self.state.phase {
663            CongestionPhase::SlowStart | CongestionPhase::CongestionAvoidance => {
664                if self.state.dup_ack_count == DUP_ACK_THRESHOLD {
665                    // Save w_max before reduction
666                    let current_cwnd = self.state.cwnd;
667
668                    // Fast convergence: if new w_max < previous w_max,
669                    // reduce w_max further to converge faster.
670                    if current_cwnd < self.prev_w_max {
671                        // w_max = cwnd * (1 + beta) / 2 = cwnd * 1717 / 2048
672                        self.w_max = ((current_cwnd as u64) * (CUBIC_BETA_DEN + CUBIC_BETA_NUM)
673                            / (2 * CUBIC_BETA_DEN)) as u32;
674                    } else {
675                        self.w_max = current_cwnd;
676                    }
677                    self.prev_w_max = current_cwnd;
678
679                    // Multiplicative decrease: cwnd = cwnd * beta
680                    // ssthresh = cwnd * beta = cwnd * 717 / 1024
681                    let new_cwnd = ((current_cwnd as u64) * CUBIC_BETA_NUM / CUBIC_BETA_DEN) as u32;
682                    self.state.ssthresh = core::cmp::max(new_cwnd, 2 * MSS);
683                    self.state.cwnd = self.state.ssthresh + 3 * MSS;
684                    self.state.phase = CongestionPhase::FastRecovery;
685
686                    self.start_epoch();
687                }
688            }
689            CongestionPhase::FastRecovery => {
690                self.state.cwnd = self.state.cwnd.saturating_add(MSS);
691            }
692        }
693    }
694
695    fn on_timeout(&mut self) {
696        // Save w_max
697        self.prev_w_max = self.w_max;
698        self.w_max = self.state.cwnd;
699
700        // ssthresh = cwnd * beta
701        let new_ssthresh = ((self.state.cwnd as u64) * CUBIC_BETA_NUM / CUBIC_BETA_DEN) as u32;
702        self.state.ssthresh = core::cmp::max(new_ssthresh, 2 * MSS);
703        self.state.cwnd = MSS;
704        self.state.dup_ack_count = 0;
705        self.state.phase = CongestionPhase::SlowStart;
706        self.state.rto = core::cmp::min(self.state.rto.saturating_mul(2), RTO_MAX_US);
707
708        // Reset epoch
709        self.epoch_start = 0;
710        self.elapsed_us = 0;
711    }
712
713    fn congestion_window(&self) -> u32 {
714        self.state.cwnd
715    }
716
717    fn slow_start_threshold(&self) -> u32 {
718        self.state.ssthresh
719    }
720}
721
722#[cfg(test)]
723mod tests {
724    use super::*;
725
726    #[test]
727    fn test_initial_state() {
728        let cc = RenoController::new();
729        assert_eq!(cc.congestion_window(), MSS);
730        assert_eq!(cc.slow_start_threshold(), u32::MAX);
731        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
732        assert_eq!(cc.dup_ack_count(), 0);
733        assert_eq!(cc.rto_us(), RTO_INITIAL_US);
734    }
735
736    #[test]
737    fn test_slow_start_growth() {
738        let mut cc = RenoController::new();
739        // In slow start, each ACK should increase cwnd by MSS
740        cc.on_ack(MSS, 50_000); // 50ms RTT
741        assert_eq!(cc.congestion_window(), 2 * MSS);
742        cc.on_ack(MSS, 50_000);
743        assert_eq!(cc.congestion_window(), 3 * MSS);
744        cc.on_ack(MSS, 50_000);
745        assert_eq!(cc.congestion_window(), 4 * MSS);
746        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
747    }
748
749    #[test]
750    fn test_slow_start_to_congestion_avoidance() {
751        let mut cc = RenoController::new();
752        // Set a low ssthresh to trigger transition
753        cc.state.ssthresh = 3 * MSS;
754        cc.on_ack(MSS, 50_000); // cwnd = 2 * MSS
755        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
756        cc.on_ack(MSS, 50_000); // cwnd = 3 * MSS >= ssthresh
757        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
758    }
759
760    #[test]
761    fn test_congestion_avoidance_linear_growth() {
762        let mut cc = RenoController::new();
763        // Force into congestion avoidance with cwnd = 4 * MSS
764        cc.state.cwnd = 4 * MSS;
765        cc.state.ssthresh = 4 * MSS;
766        cc.state.phase = CongestionPhase::CongestionAvoidance;
767
768        let initial_cwnd = cc.congestion_window();
769        // Each ACK should add approximately MSS * MSS / cwnd bytes
770        cc.on_ack(MSS, 50_000);
771        let increment = cc.congestion_window() - initial_cwnd;
772
773        // Expected increment: MSS * MSS / (4 * MSS) = MSS / 4 = 365
774        let expected = MSS / 4;
775        assert_eq!(increment, expected);
776    }
777
778    #[test]
779    fn test_congestion_avoidance_minimum_increment() {
780        let mut cc = RenoController::new();
781        // Very large cwnd to test minimum increment
782        cc.state.cwnd = u32::MAX / 2;
783        cc.state.ssthresh = MSS;
784        cc.state.phase = CongestionPhase::CongestionAvoidance;
785
786        let initial = cc.congestion_window();
787        cc.on_ack(MSS, 50_000);
788        // Should increase by at least 1 byte
789        assert!(cc.congestion_window() > initial);
790    }
791
792    #[test]
793    fn test_fast_retransmit_on_3_dup_acks() {
794        let mut cc = RenoController::new();
795        cc.state.cwnd = 10 * MSS;
796        cc.state.phase = CongestionPhase::CongestionAvoidance;
797
798        let original_cwnd = cc.congestion_window();
799
800        // 3 duplicate ACKs trigger fast retransmit / recovery
801        cc.on_duplicate_ack();
802        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
803        cc.on_duplicate_ack();
804        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
805        cc.on_duplicate_ack();
806        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);
807
808        // ssthresh = max(cwnd / 2, 2 * MSS)
809        let expected_ssthresh = core::cmp::max(original_cwnd / 2, 2 * MSS);
810        assert_eq!(cc.slow_start_threshold(), expected_ssthresh);
811
812        // cwnd = ssthresh + 3 * MSS
813        assert_eq!(cc.congestion_window(), expected_ssthresh + 3 * MSS);
814    }
815
816    #[test]
817    fn test_fast_recovery_inflation() {
818        let mut cc = RenoController::new();
819        cc.state.cwnd = 10 * MSS;
820        cc.state.phase = CongestionPhase::CongestionAvoidance;
821
822        // Trigger fast recovery
823        for _ in 0..3 {
824            cc.on_duplicate_ack();
825        }
826        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);
827        let cwnd_after_fr = cc.congestion_window();
828
829        // Additional dup ACKs inflate cwnd by MSS each
830        cc.on_duplicate_ack();
831        assert_eq!(cc.congestion_window(), cwnd_after_fr + MSS);
832        cc.on_duplicate_ack();
833        assert_eq!(cc.congestion_window(), cwnd_after_fr + 2 * MSS);
834    }
835
836    #[test]
837    fn test_fast_recovery_exit_on_new_ack() {
838        let mut cc = RenoController::new();
839        cc.state.cwnd = 10 * MSS;
840        cc.state.phase = CongestionPhase::CongestionAvoidance;
841
842        // Enter fast recovery
843        for _ in 0..3 {
844            cc.on_duplicate_ack();
845        }
846        let ssthresh = cc.slow_start_threshold();
847
848        // New ACK should deflate cwnd to ssthresh and enter congestion avoidance
849        cc.on_ack(MSS, 50_000);
850        assert_eq!(cc.congestion_window(), ssthresh);
851        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
852        assert_eq!(cc.dup_ack_count(), 0);
853    }
854
855    #[test]
856    fn test_timeout_resets_to_slow_start() {
857        let mut cc = RenoController::new();
858        cc.state.cwnd = 20 * MSS;
859        cc.state.ssthresh = 15 * MSS;
860        cc.state.phase = CongestionPhase::CongestionAvoidance;
861
862        let cwnd_before = cc.congestion_window();
863        cc.on_timeout();
864
865        // cwnd should reset to 1 MSS
866        assert_eq!(cc.congestion_window(), MSS);
867        // ssthresh = max(cwnd / 2, 2 * MSS)
868        assert_eq!(
869            cc.slow_start_threshold(),
870            core::cmp::max(cwnd_before / 2, 2 * MSS)
871        );
872        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
873    }
874
875    #[test]
876    fn test_timeout_doubles_rto() {
877        let mut cc = RenoController::new();
878        let rto_initial = cc.rto_us();
879
880        cc.on_timeout();
881        assert_eq!(cc.rto_us(), rto_initial * 2);
882
883        cc.on_timeout();
884        assert_eq!(cc.rto_us(), rto_initial * 4);
885    }
886
887    #[test]
888    fn test_rto_clamped_to_max() {
889        let mut cc = RenoController::new();
890        // Force RTO near max
891        cc.state.rto = RTO_MAX_US;
892        cc.on_timeout();
893        assert_eq!(cc.rto_us(), RTO_MAX_US);
894    }
895
896    #[test]
897    fn test_rto_first_rtt_sample() {
898        let mut cc = RenoController::new();
899        let rtt = 100_000u64; // 100ms
900
901        cc.on_ack(MSS, rtt);
902
903        // After first sample: SRTT = R, RTTVAR = R/2
904        assert_eq!(cc.srtt_us(), rtt);
905        assert_eq!(cc.rttvar_us(), rtt / 2);
906
907        // RTO = SRTT + max(G, 4 * RTTVAR) = 100ms + max(1ms, 4 * 50ms) = 100ms + 200ms
908        // = 300ms
909        let expected_rto = rtt + 4 * (rtt / 2);
910        // Clamp to at least RTO_MIN
911        let expected_rto = expected_rto.clamp(RTO_MIN_US, RTO_MAX_US);
912        assert_eq!(cc.rto_us(), expected_rto);
913    }
914
915    #[test]
916    fn test_rto_subsequent_rtt_sample() {
917        let mut cc = RenoController::new();
918
919        // First sample: 100ms
920        cc.on_ack(MSS, 100_000);
921        let srtt_after_first = cc.srtt_us();
922        assert_eq!(srtt_after_first, 100_000);
923
924        // Second sample: 120ms
925        cc.on_ack(MSS, 120_000);
926        let srtt_after_second = cc.srtt_us();
927
928        // SRTT should move toward 120ms but not reach it
929        // SRTT = SRTT - SRTT/8 + R/8 = 100000 - 12500 + 15000 = 102500
930        assert_eq!(srtt_after_second, 102_500);
931    }
932
933    #[test]
934    fn test_rto_minimum_enforced() {
935        let mut cc = RenoController::new();
936        // Very small RTT should still result in RTO >= RTO_MIN
937        cc.on_ack(MSS, 100); // 0.1ms
938        assert!(cc.rto_us() >= RTO_MIN_US);
939    }
940
941    #[test]
942    fn test_zero_rtt_ignored() {
943        let mut cc = RenoController::new();
944        let rto_before = cc.rto_us();
945        cc.on_ack(MSS, 0);
946        // RTT of 0 should not update RTT estimates; RTO stays at initial
947        assert_eq!(cc.rto_us(), rto_before);
948    }
949
950    #[test]
951    fn test_ssthresh_minimum_2mss() {
952        let mut cc = RenoController::new();
953        // Very small cwnd: cwnd / 2 < 2 * MSS
954        cc.state.cwnd = MSS;
955        cc.on_timeout();
956        assert_eq!(cc.slow_start_threshold(), 2 * MSS);
957    }
958
959    #[test]
960    fn test_dup_ack_count_reset_on_new_ack() {
961        let mut cc = RenoController::new();
962        cc.on_duplicate_ack();
963        cc.on_duplicate_ack();
964        assert_eq!(cc.dup_ack_count(), 2);
965
966        // New ACK resets dup count
967        cc.on_ack(MSS, 50_000);
968        assert_eq!(cc.dup_ack_count(), 0);
969    }
970
971    #[test]
972    fn test_full_congestion_cycle() {
973        let mut cc = RenoController::new();
974
975        // Phase 1: Slow start from 1 MSS
976        for _ in 0..5 {
977            cc.on_ack(MSS, 50_000);
978        }
979        assert_eq!(cc.congestion_window(), 6 * MSS);
980        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
981
982        // Phase 2: Timeout -- reset to slow start
983        cc.on_timeout();
984        assert_eq!(cc.congestion_window(), MSS);
985        assert_eq!(
986            cc.slow_start_threshold(),
987            core::cmp::max(6 * MSS / 2, 2 * MSS)
988        );
989        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
990
991        // Phase 3: Grow back past ssthresh into congestion avoidance
992        let ssthresh = cc.slow_start_threshold();
993        while cc.phase() == CongestionPhase::SlowStart {
994            cc.on_ack(MSS, 50_000);
995        }
996        assert!(cc.congestion_window() >= ssthresh);
997        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
998
999        // Phase 4: 3 dup ACKs -> fast recovery
1000        for _ in 0..3 {
1001            cc.on_duplicate_ack();
1002        }
1003        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);
1004
1005        // Phase 5: New ACK exits fast recovery
1006        cc.on_ack(MSS, 50_000);
1007        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
1008    }
1009
1010    // -----------------------------------------------------------------------
1011    // Cubic controller tests
1012    // -----------------------------------------------------------------------
1013
1014    #[test]
1015    fn test_integer_cbrt_exact_cubes() {
1016        assert_eq!(integer_cbrt(0), 0);
1017        assert_eq!(integer_cbrt(1), 1);
1018        assert_eq!(integer_cbrt(8), 2);
1019        assert_eq!(integer_cbrt(27), 3);
1020        assert_eq!(integer_cbrt(64), 4);
1021        assert_eq!(integer_cbrt(125), 5);
1022        assert_eq!(integer_cbrt(1000), 10);
1023        assert_eq!(integer_cbrt(1_000_000), 100);
1024        assert_eq!(integer_cbrt(1_000_000_000), 1000);
1025    }
1026
1027    #[test]
1028    fn test_integer_cbrt_non_exact() {
1029        // floor of cube root
1030        assert_eq!(integer_cbrt(2), 1);
1031        assert_eq!(integer_cbrt(7), 1);
1032        assert_eq!(integer_cbrt(9), 2);
1033        assert_eq!(integer_cbrt(26), 2);
1034        assert_eq!(integer_cbrt(63), 3);
1035        assert_eq!(integer_cbrt(100), 4);
1036        // Verify: 4^3 = 64 <= 100 < 125 = 5^3
1037    }
1038
1039    #[test]
1040    fn test_integer_cbrt_large_values() {
1041        // Typical cwnd-related values
1042        let val = 1_000_000_000_000u64; // 10^12
1043        let root = integer_cbrt(val);
1044        assert_eq!(root, 10_000); // 10000^3 = 10^12
1045
1046        // Maximum-ish value
1047        let root = integer_cbrt(u64::MAX);
1048        // 2642245^3 ≈ 1.844 * 10^19 ≈ u64::MAX
1049        assert!(root >= 2_642_245);
1050        // Verify root^3 <= u64::MAX (doesn't overflow)
1051        assert!(root
1052            .checked_mul(root)
1053            .and_then(|r2| r2.checked_mul(root))
1054            .is_some());
1055        // Verify (root+1)^3 overflows u64 (proves root is the floor cbrt)
1056        let r1 = root + 1;
1057        assert!(r1
1058            .checked_mul(r1)
1059            .and_then(|r2| r2.checked_mul(r1))
1060            .is_none());
1061    }
1062
1063    #[test]
1064    fn test_cubic_initial_state() {
1065        let cc = CubicController::new();
1066        assert_eq!(cc.congestion_window(), MSS);
1067        assert_eq!(cc.slow_start_threshold(), u32::MAX);
1068        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
1069        assert_eq!(cc.w_max, 0);
1070    }
1071
1072    #[test]
1073    fn test_cubic_slow_start_identical_to_reno() {
1074        let mut cubic = CubicController::new();
1075        let mut reno = RenoController::new();
1076
1077        // Both should grow identically during slow start
1078        for _ in 0..5 {
1079            cubic.on_ack(MSS, 50_000);
1080            reno.on_ack(MSS, 50_000);
1081        }
1082        assert_eq!(cubic.congestion_window(), reno.congestion_window());
1083        assert_eq!(cubic.phase(), CongestionPhase::SlowStart);
1084    }
1085
1086    #[test]
1087    fn test_cubic_loss_response_beta_07() {
1088        let mut cc = CubicController::new();
1089        // Set up: large cwnd in congestion avoidance
1090        cc.state.cwnd = 100 * MSS;
1091        cc.state.ssthresh = 50 * MSS;
1092        cc.state.phase = CongestionPhase::CongestionAvoidance;
1093        cc.w_max = 100 * MSS;
1094
1095        let cwnd_before = cc.congestion_window();
1096
1097        // 3 dup ACKs trigger loss
1098        for _ in 0..3 {
1099            cc.on_duplicate_ack();
1100        }
1101
1102        // ssthresh should be cwnd * 0.7 (beta)
1103        let expected_ssthresh = ((cwnd_before as u64) * CUBIC_BETA_NUM / CUBIC_BETA_DEN) as u32;
1104        assert_eq!(cc.slow_start_threshold(), expected_ssthresh);
1105
1106        // w_max should record the pre-loss cwnd
1107        assert_eq!(cc.w_max, cwnd_before);
1108    }
1109
1110    #[test]
1111    fn test_cubic_fast_convergence() {
1112        let mut cc = CubicController::new();
1113        cc.state.cwnd = 100 * MSS;
1114        cc.state.ssthresh = 50 * MSS;
1115        cc.state.phase = CongestionPhase::CongestionAvoidance;
1116        cc.prev_w_max = 120 * MSS; // Previous w_max was higher
1117
1118        let cwnd_before = cc.congestion_window();
1119
1120        // Trigger loss: cwnd (100) < prev_w_max (120), so fast convergence applies
1121        for _ in 0..3 {
1122            cc.on_duplicate_ack();
1123        }
1124
1125        // w_max should be reduced: cwnd * (1 + beta) / 2 = 100 * 1.7 / 2 = 85 segments
1126        let expected_w_max = ((cwnd_before as u64) * (CUBIC_BETA_DEN + CUBIC_BETA_NUM)
1127            / (2 * CUBIC_BETA_DEN)) as u32;
1128        assert_eq!(cc.w_max, expected_w_max);
1129        // Fast convergence w_max should be less than normal w_max
1130        assert!(cc.w_max < cwnd_before);
1131    }
1132
1133    #[test]
1134    fn test_cubic_timeout_resets() {
1135        let mut cc = CubicController::new();
1136        cc.state.cwnd = 50 * MSS;
1137        cc.state.phase = CongestionPhase::CongestionAvoidance;
1138
1139        cc.on_timeout();
1140
1141        assert_eq!(cc.congestion_window(), MSS);
1142        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
1143        // ssthresh = cwnd * beta = 50 * 717 / 1024 ~= 35 segments
1144        let expected = ((50 * MSS as u64) * CUBIC_BETA_NUM / CUBIC_BETA_DEN) as u32;
1145        assert_eq!(cc.slow_start_threshold(), expected);
1146    }
1147
1148    #[test]
1149    fn test_cubic_growth_after_loss() {
1150        let mut cc = CubicController::new();
1151        // Initialize RTT
1152        cc.state.cwnd = 100 * MSS;
1153        cc.state.ssthresh = 50 * MSS;
1154        cc.state.phase = CongestionPhase::CongestionAvoidance;
1155
1156        // Prime RTT estimate
1157        cc.state.update_rtt(50_000); // 50ms
1158
1159        // Trigger loss
1160        for _ in 0..3 {
1161            cc.on_duplicate_ack();
1162        }
1163        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);
1164
1165        // Exit fast recovery with new ACK
1166        cc.on_ack(MSS, 50_000);
1167        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
1168
1169        let cwnd_after_loss = cc.congestion_window();
1170
1171        // Continue with ACKs -- cwnd should grow
1172        for _ in 0..20 {
1173            cc.on_ack(MSS, 50_000);
1174        }
1175        assert!(
1176            cc.congestion_window() > cwnd_after_loss,
1177            "cwnd should grow after loss: {} vs {}",
1178            cc.congestion_window(),
1179            cwnd_after_loss
1180        );
1181    }
1182
1183    #[test]
1184    fn test_cubic_tcp_friendly_region() {
1185        // In the TCP-friendly region (early after loss with small elapsed time),
1186        // Cubic should behave at least as well as Reno.
1187        let mut cc = CubicController::new();
1188        cc.state.update_rtt(50_000);
1189        cc.state.cwnd = 10 * MSS;
1190        cc.state.ssthresh = 10 * MSS;
1191        cc.state.phase = CongestionPhase::CongestionAvoidance;
1192        cc.w_max = 10 * MSS;
1193        cc.origin_point = 10;
1194        cc.k_us = cc.compute_k(10);
1195        cc.elapsed_us = 0;
1196
1197        // TCP-friendly window should be at least 1 MSS
1198        let w_tcp = cc.tcp_friendly_window(0);
1199        assert!(w_tcp >= MSS);
1200
1201        // After some time, TCP-friendly should grow
1202        let w_tcp_later = cc.tcp_friendly_window(500_000); // 500ms
1203        assert!(w_tcp_later >= w_tcp);
1204    }
1205
1206    #[test]
1207    fn test_cubic_window_concave_then_convex() {
1208        // The cubic function should be concave (below w_max) before K
1209        // and convex (above w_max) after K.
1210        let mut cc = CubicController::new();
1211        cc.state.update_rtt(50_000);
1212        cc.w_max = 100 * MSS;
1213        cc.origin_point = 100;
1214        cc.k_us = cc.compute_k(100);
1215
1216        // Well before K: window should be below w_max
1217        if cc.k_us > 200_000 {
1218            let w_early = cc.cubic_window(cc.k_us / 4);
1219            assert!(
1220                w_early < cc.w_max,
1221                "early window {} should be below w_max {}",
1222                w_early,
1223                cc.w_max
1224            );
1225        }
1226
1227        // At K: window should be approximately w_max
1228        let w_at_k = cc.cubic_window(cc.k_us);
1229        let tolerance = 5 * MSS; // Allow some fixed-point rounding
1230        let diff = if w_at_k > cc.w_max {
1231            w_at_k - cc.w_max
1232        } else {
1233            cc.w_max - w_at_k
1234        };
1235        assert!(
1236            diff <= tolerance,
1237            "window at K ({}) should be close to w_max ({}), diff={}",
1238            w_at_k,
1239            cc.w_max,
1240            diff
1241        );
1242
1243        // Well after K: window should exceed w_max
1244        let w_late = cc.cubic_window(cc.k_us * 3);
1245        assert!(
1246            w_late > cc.w_max,
1247            "late window {} should exceed w_max {}",
1248            w_late,
1249            cc.w_max
1250        );
1251    }
1252
1253    #[test]
1254    fn test_cubic_full_congestion_cycle() {
1255        let mut cc = CubicController::new();
1256
1257        // Phase 1: Slow start
1258        for _ in 0..5 {
1259            cc.on_ack(MSS, 50_000);
1260        }
1261        assert_eq!(cc.congestion_window(), 6 * MSS);
1262        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
1263
1264        // Phase 2: Timeout
1265        cc.on_timeout();
1266        assert_eq!(cc.congestion_window(), MSS);
1267        assert_eq!(cc.phase(), CongestionPhase::SlowStart);
1268        assert!(cc.w_max > 0);
1269
1270        // Phase 3: Grow through slow start into CA
1271        while cc.phase() == CongestionPhase::SlowStart {
1272            cc.on_ack(MSS, 50_000);
1273        }
1274        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
1275
1276        // Phase 4: Grow in CA
1277        let cwnd_at_ca = cc.congestion_window();
1278        for _ in 0..10 {
1279            cc.on_ack(MSS, 50_000);
1280        }
1281        assert!(cc.congestion_window() >= cwnd_at_ca);
1282
1283        // Phase 5: 3 dup ACKs -> fast recovery
1284        for _ in 0..3 {
1285            cc.on_duplicate_ack();
1286        }
1287        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);
1288
1289        // Phase 6: Exit fast recovery
1290        cc.on_ack(MSS, 50_000);
1291        assert_eq!(cc.phase(), CongestionPhase::CongestionAvoidance);
1292    }
1293
1294    #[test]
1295    fn test_cubic_recovery_inflation() {
1296        let mut cc = CubicController::new();
1297        cc.state.cwnd = 20 * MSS;
1298        cc.state.phase = CongestionPhase::CongestionAvoidance;
1299
1300        // Trigger fast recovery
1301        for _ in 0..3 {
1302            cc.on_duplicate_ack();
1303        }
1304        assert_eq!(cc.phase(), CongestionPhase::FastRecovery);
1305        let cwnd_fr = cc.congestion_window();
1306
1307        // Additional dup ACKs should inflate cwnd by MSS
1308        cc.on_duplicate_ack();
1309        assert_eq!(cc.congestion_window(), cwnd_fr + MSS);
1310        cc.on_duplicate_ack();
1311        assert_eq!(cc.congestion_window(), cwnd_fr + 2 * MSS);
1312    }
1313}