⚠️ VeridianOS Kernel Documentation — This is low-level `no_std` kernel code. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/net/
ntp.rs

1//! NTP client for VeridianOS time synchronization
2//!
3//! Implements NTPv4 (RFC 5905) client mode with:
4//! - 48-byte NTP packet serialization/deserialization
5//! - Clock offset and round-trip delay calculation (integer-only)
6//! - Marzullo's algorithm for multi-source time selection
7//! - Clock filter (best-of-8 by minimum delay)
8//! - Drift compensation via integer linear regression
9//! - RTC integration for system clock adjustment
10
11#![allow(dead_code)]
12
13#[cfg(feature = "alloc")]
14use alloc::vec::Vec;
15use core::sync::atomic::{AtomicI64, AtomicU64, Ordering};
16
17// ---------------------------------------------------------------------------
18// NTP constants
19// ---------------------------------------------------------------------------
20
// --- Wire format -------------------------------------------------------------

/// UDP port on which NTP servers listen.
pub const NTP_PORT: u16 = 123;

/// Size of the fixed NTP packet header, in bytes.
pub const NTP_PACKET_SIZE: usize = 48;

/// Protocol version written into requests (NTPv4).
const NTP_VERSION: u8 = 4;

/// Mode field value of a client request.
const MODE_CLIENT: u8 = 3;

/// Mode field value of a server reply.
const MODE_SERVER: u8 = 4;

// --- Epoch conversion ----------------------------------------------------------

/// Seconds from the NTP epoch (1900-01-01) to the Unix epoch (1970-01-01):
/// 70 years of 365 days plus 17 leap days (= 2_208_988_800).
pub const NTP_UNIX_OFFSET: u64 = (70 * 365 + 17) * 86_400;

// --- Polling and filtering policy ----------------------------------------------

/// Shortest poll interval, in seconds (2^6 = 64).
pub const MIN_POLL_INTERVAL: u32 = 1 << 6;

/// Longest poll interval, in seconds (2^10 = 1024).
pub const MAX_POLL_INTERVAL: u32 = 1 << 10;

/// Number of samples retained by the clock filter.
pub const CLOCK_FILTER_SIZE: usize = 8;

/// Round-trip delays above this many milliseconds are discarded as outliers.
const MAX_DELAY_MS: i64 = 5000;

/// Corrections whose magnitude exceeds this many milliseconds are applied as
/// a step rather than a slew.
const STEP_THRESHOLD_MS: i64 = 128;
55
56// ---------------------------------------------------------------------------
57// Leap Indicator
58// ---------------------------------------------------------------------------
59
/// The two-bit leap indicator (LI) carried in the top bits of the first
/// header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum LeapIndicator {
    /// No leap second pending
    NoWarning = 0,
    /// The last minute of the day will have 61 seconds
    AddSecond = 1,
    /// The last minute of the day will have 59 seconds
    DeleteSecond = 2,
    /// Alarm: the sender's clock is not synchronized
    Unsynchronized = 3,
}

impl LeapIndicator {
    /// Decode the indicator from the two least-significant bits of `val`;
    /// any higher bits are ignored.
    fn from_u8(val: u8) -> Self {
        const BY_CODE: [LeapIndicator; 4] = [
            LeapIndicator::NoWarning,
            LeapIndicator::AddSecond,
            LeapIndicator::DeleteSecond,
            LeapIndicator::Unsynchronized,
        ];
        BY_CODE[usize::from(val & 0b11)]
    }
}
84
85// ---------------------------------------------------------------------------
86// NTP Timestamp (32.32 fixed-point)
87// ---------------------------------------------------------------------------
88
89/// NTP timestamp: 32-bit seconds since 1900-01-01 + 32-bit fraction.
90#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
91pub struct NtpTimestamp {
92    /// Seconds since 1900-01-01
93    pub seconds: u32,
94    /// Fractional seconds (2^-32 second units)
95    pub fraction: u32,
96}
97
98impl NtpTimestamp {
99    /// Create a new NTP timestamp.
100    pub fn new(seconds: u32, fraction: u32) -> Self {
101        Self { seconds, fraction }
102    }
103
104    /// Create from Unix epoch seconds + milliseconds.
105    pub fn from_unix(epoch_secs: u64, millis: u32) -> Self {
106        let ntp_secs = epoch_secs.saturating_add(NTP_UNIX_OFFSET);
107        // Convert milliseconds to NTP fraction: frac = millis * 2^32 / 1000
108        let fraction = ((millis as u64) * 4_294_967_296 / 1000) as u32;
109        Self {
110            seconds: ntp_secs as u32,
111            fraction,
112        }
113    }
114
115    /// Convert to Unix epoch seconds (truncated).
116    pub fn to_unix_secs(&self) -> u64 {
117        (self.seconds as u64).saturating_sub(NTP_UNIX_OFFSET)
118    }
119
120    /// Convert to milliseconds since NTP epoch.
121    pub fn to_millis(&self) -> u64 {
122        let secs_ms = (self.seconds as u64) * 1000;
123        // fraction / 2^32 * 1000 = fraction * 1000 / 2^32
124        let frac_ms = (self.fraction as u64) * 1000 / 4_294_967_296;
125        secs_ms.saturating_add(frac_ms)
126    }
127
128    /// Serialize to 8 big-endian bytes.
129    pub fn to_bytes(&self) -> [u8; 8] {
130        let mut buf = [0u8; 8];
131        buf[0..4].copy_from_slice(&self.seconds.to_be_bytes());
132        buf[4..8].copy_from_slice(&self.fraction.to_be_bytes());
133        buf
134    }
135
136    /// Deserialize from 8 big-endian bytes.
137    pub fn from_bytes(bytes: &[u8; 8]) -> Self {
138        Self {
139            seconds: u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
140            fraction: u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]),
141        }
142    }
143}
144
145// ---------------------------------------------------------------------------
146// NTP Packet (48 bytes)
147// ---------------------------------------------------------------------------
148
149/// NTPv4 packet (48 bytes, RFC 5905).
150#[derive(Debug, Clone, PartialEq, Eq)]
151pub struct NtpPacket {
152    /// Leap indicator (2 bits)
153    pub leap: LeapIndicator,
154    /// Version number (3 bits)
155    pub version: u8,
156    /// Mode (3 bits)
157    pub mode: u8,
158    /// Stratum level (0 = unspecified, 1 = primary, 2-15 = secondary)
159    pub stratum: u8,
160    /// Maximum interval between successive messages (log2 seconds)
161    pub poll: i8,
162    /// Precision of the system clock (log2 seconds)
163    pub precision: i8,
164    /// Total round-trip delay to the reference source (NTP short format, 16.16)
165    pub root_delay: u32,
166    /// Maximum error relative to the reference source (NTP short format, 16.16)
167    pub root_dispersion: u32,
168    /// Reference identifier (stratum 1: 4-char ASCII, else: IP address)
169    pub reference_id: [u8; 4],
170    /// Time when the system clock was last set
171    pub reference_ts: NtpTimestamp,
172    /// Time at the client when the request departed
173    pub origin_ts: NtpTimestamp,
174    /// Time at the server when the request arrived
175    pub receive_ts: NtpTimestamp,
176    /// Time at the server when the response departed
177    pub transmit_ts: NtpTimestamp,
178}
179
180impl NtpPacket {
181    /// Create a new client request packet.
182    pub fn new_request(transmit_ts: NtpTimestamp) -> Self {
183        Self {
184            leap: LeapIndicator::NoWarning,
185            version: NTP_VERSION,
186            mode: MODE_CLIENT,
187            stratum: 0,
188            poll: 6,        // 2^6 = 64s
189            precision: -18, // ~3.8 microseconds
190            root_delay: 0,
191            root_dispersion: 0,
192            reference_id: [0; 4],
193            reference_ts: NtpTimestamp::default(),
194            origin_ts: NtpTimestamp::default(),
195            receive_ts: NtpTimestamp::default(),
196            transmit_ts,
197        }
198    }
199
200    /// Serialize to a 48-byte array.
201    pub fn to_bytes(&self) -> [u8; NTP_PACKET_SIZE] {
202        let mut buf = [0u8; NTP_PACKET_SIZE];
203        // Byte 0: LI (2) | VN (3) | Mode (3)
204        buf[0] = ((self.leap as u8) << 6) | ((self.version & 0x07) << 3) | (self.mode & 0x07);
205        buf[1] = self.stratum;
206        buf[2] = self.poll as u8;
207        buf[3] = self.precision as u8;
208        buf[4..8].copy_from_slice(&self.root_delay.to_be_bytes());
209        buf[8..12].copy_from_slice(&self.root_dispersion.to_be_bytes());
210        buf[12..16].copy_from_slice(&self.reference_id);
211        buf[16..24].copy_from_slice(&self.reference_ts.to_bytes());
212        buf[24..32].copy_from_slice(&self.origin_ts.to_bytes());
213        buf[32..40].copy_from_slice(&self.receive_ts.to_bytes());
214        buf[40..48].copy_from_slice(&self.transmit_ts.to_bytes());
215        buf
216    }
217
218    /// Deserialize from a 48-byte slice.
219    pub fn from_bytes(buf: &[u8]) -> Option<Self> {
220        if buf.len() < NTP_PACKET_SIZE {
221            return None;
222        }
223        let li = (buf[0] >> 6) & 0x03;
224        let vn = (buf[0] >> 3) & 0x07;
225        let mode = buf[0] & 0x07;
226
227        Some(Self {
228            leap: LeapIndicator::from_u8(li),
229            version: vn,
230            mode,
231            stratum: buf[1],
232            poll: buf[2] as i8,
233            precision: buf[3] as i8,
234            root_delay: u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]),
235            root_dispersion: u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]),
236            reference_id: [buf[12], buf[13], buf[14], buf[15]],
237            reference_ts: NtpTimestamp::from_bytes(buf[16..24].try_into().ok()?),
238            origin_ts: NtpTimestamp::from_bytes(buf[24..32].try_into().ok()?),
239            receive_ts: NtpTimestamp::from_bytes(buf[32..40].try_into().ok()?),
240            transmit_ts: NtpTimestamp::from_bytes(buf[40..48].try_into().ok()?),
241        })
242    }
243
244    /// Check if this is a Kiss-o'-Death (KoD) packet.
245    ///
246    /// KoD packets have stratum == 0 and a 4-char ASCII code in reference_id.
247    pub fn is_kod(&self) -> bool {
248        self.stratum == 0
249            && self
250                .reference_id
251                .iter()
252                .all(|&b| (0x20..=0x7E).contains(&b))
253    }
254
255    /// Get the KoD code (e.g., "DENY", "RATE", "RSTR") if this is a KoD packet.
256    pub fn kod_code(&self) -> Option<[u8; 4]> {
257        if self.is_kod() {
258            Some(self.reference_id)
259        } else {
260            None
261        }
262    }
263}
264
265// ---------------------------------------------------------------------------
266// Clock offset / delay calculation
267// ---------------------------------------------------------------------------
268
269/// Result of processing an NTP response.
270#[derive(Debug, Clone, Copy, PartialEq, Eq)]
271pub struct NtpMeasurement {
272    /// Clock offset in milliseconds (positive = local behind server)
273    pub offset_ms: i64,
274    /// Round-trip delay in milliseconds
275    pub delay_ms: i64,
276}
277
278/// Calculate clock offset and round-trip delay from NTP timestamps.
279///
280/// Using the standard NTP formulas (all in milliseconds):
281/// - offset = ((t2 - t1) + (t3 - t4)) / 2
282/// - delay  = (t4 - t1) - (t3 - t2)
283///
284/// Where:
285/// - t1 = client transmit time (origin)
286/// - t2 = server receive time
287/// - t3 = server transmit time
288/// - t4 = client receive time
289pub fn calculate_offset_delay(
290    t1: &NtpTimestamp,
291    t2: &NtpTimestamp,
292    t3: &NtpTimestamp,
293    t4: &NtpTimestamp,
294) -> NtpMeasurement {
295    let t1_ms = t1.to_millis() as i64;
296    let t2_ms = t2.to_millis() as i64;
297    let t3_ms = t3.to_millis() as i64;
298    let t4_ms = t4.to_millis() as i64;
299
300    // offset = ((t2 - t1) + (t3 - t4)) / 2
301    let offset_ms = ((t2_ms - t1_ms) + (t3_ms - t4_ms)) / 2;
302
303    // delay = (t4 - t1) - (t3 - t2)
304    let delay_ms = (t4_ms - t1_ms) - (t3_ms - t2_ms);
305
306    NtpMeasurement {
307        offset_ms,
308        delay_ms,
309    }
310}
311
312// ---------------------------------------------------------------------------
313// Clock Filter (best of N samples by minimum delay)
314// ---------------------------------------------------------------------------
315
316/// Clock filter: keeps the last CLOCK_FILTER_SIZE samples and selects the
317/// one with the smallest round-trip delay as the best estimate.
318#[derive(Debug)]
319pub struct ClockFilter {
320    samples: [Option<NtpMeasurement>; CLOCK_FILTER_SIZE],
321    next_idx: usize,
322    count: usize,
323}
324
325impl Default for ClockFilter {
326    fn default() -> Self {
327        Self::new()
328    }
329}
330
331impl ClockFilter {
332    /// Create a new empty clock filter.
333    pub fn new() -> Self {
334        Self {
335            samples: [None; CLOCK_FILTER_SIZE],
336            next_idx: 0,
337            count: 0,
338        }
339    }
340
341    /// Add a new measurement to the filter.
342    pub fn add_sample(&mut self, sample: NtpMeasurement) {
343        self.samples[self.next_idx] = Some(sample);
344        self.next_idx = (self.next_idx + 1) % CLOCK_FILTER_SIZE;
345        if self.count < CLOCK_FILTER_SIZE {
346            self.count += 1;
347        }
348    }
349
350    /// Select the best sample (lowest absolute delay).
351    pub fn best_sample(&self) -> Option<NtpMeasurement> {
352        let mut best: Option<NtpMeasurement> = None;
353        for s in self.samples.iter().flatten() {
354            match best {
355                None => best = Some(*s),
356                Some(b) if s.delay_ms.unsigned_abs() < b.delay_ms.unsigned_abs() => {
357                    best = Some(*s);
358                }
359                _ => {}
360            }
361        }
362        best
363    }
364
365    /// Compute jitter as root-mean-square of successive offset differences
366    /// (integer square root approximation).
367    pub fn jitter_ms(&self) -> u64 {
368        if self.count < 2 {
369            return 0;
370        }
371
372        // Collect valid samples in order
373        let mut offsets = [0i64; CLOCK_FILTER_SIZE];
374        let mut n = 0usize;
375        for s in self.samples.iter().flatten() {
376            if n < CLOCK_FILTER_SIZE {
377                offsets[n] = s.offset_ms;
378                n += 1;
379            }
380        }
381
382        if n < 2 {
383            return 0;
384        }
385
386        // Sum of squared differences
387        let mut sum_sq: u64 = 0;
388        for i in 1..n {
389            let diff = offsets[i] - offsets[i - 1];
390            // Use checked arithmetic to avoid overflow
391            let sq = (diff as i128) * (diff as i128);
392            sum_sq = sum_sq.saturating_add(sq as u64);
393        }
394
395        let mean_sq = sum_sq / (n as u64 - 1);
396        // Integer square root via Newton's method
397        isqrt(mean_sq)
398    }
399
400    /// Number of valid samples in the filter.
401    pub fn sample_count(&self) -> usize {
402        self.count
403    }
404}
405
406/// Integer square root using Newton's method.
407fn isqrt(n: u64) -> u64 {
408    if n == 0 {
409        return 0;
410    }
411    let mut x = n;
412    let mut y = x.div_ceil(2);
413    while y < x {
414        x = y;
415        y = (x + n / x) / 2;
416    }
417    x
418}
419
420// ---------------------------------------------------------------------------
421// Marzullo's Algorithm
422// ---------------------------------------------------------------------------
423
/// An interval from a time source for Marzullo's algorithm.
#[derive(Debug, Clone, Copy)]
pub struct TimeInterval {
    /// Lower bound (offset - delay/2) in milliseconds
    pub low: i64,
    /// Upper bound (offset + delay/2) in milliseconds
    pub high: i64,
}

/// Run Marzullo's algorithm on a set of time intervals.
///
/// Sweeps every interval endpoint to find the region covered by the largest
/// number of intervals.
/// - If that number is at least `ceil(n / 2)` of the `n` sources, returns the
///   midpoint of that region.
/// - Otherwise returns the midpoint of the narrowest single interval.
/// - A single interval yields its own midpoint; `None` is returned only for
///   empty input.
///
/// Intervals are closed, so intervals that merely touch at an endpoint count
/// as overlapping. Midpoints are computed in i128, so extreme bounds cannot
/// overflow.
#[cfg(feature = "alloc")]
pub fn marzullo_select(intervals: &[TimeInterval]) -> Option<i64> {
    // Overflow-free (low + high) / 2, truncating toward zero like i64 division.
    let midpoint = |low: i64, high: i64| ((i128::from(low) + i128::from(high)) / 2) as i64;

    match intervals {
        [] => return None,
        [only] => return Some(midpoint(only.low, only.high)),
        _ => {}
    }

    // Endpoint table of (value, kind): kind -1 opens an interval, +1 closes one.
    let mut endpoints: Vec<(i64, i32)> = Vec::with_capacity(intervals.len() * 2);
    for iv in intervals {
        endpoints.push((iv.low, -1));
        endpoints.push((iv.high, 1));
    }
    // Lexicographic order: by value, and at equal values opens (-1) before closes (+1).
    endpoints.sort_unstable();

    // Sweep the whole table and remember the region with the highest overlap.
    // Stopping at the first region that reaches the threshold would miss a
    // later region supported by more sources.
    let mut count: i64 = 0;
    let mut best_count: i64 = 0;
    let mut best_low = 0i64;
    let mut best_high = 0i64;
    for (i, &(value, kind)) in endpoints.iter().enumerate() {
        count -= i64::from(kind);
        if count > best_count {
            best_count = count;
            best_low = value;
            // The region lasts until the next endpoint. Another open would raise
            // the count again (replacing this region); a close ends it.
            best_high = endpoints.get(i + 1).map_or(value, |&(next, _)| next);
        }
    }

    let n = intervals.len() as i64;
    let majority = (n + 1) / 2; // ceil(n/2)
    if best_count >= majority {
        Some(midpoint(best_low, best_high))
    } else {
        // Fallback: midpoint of the narrowest interval (smallest delay); the
        // first such interval wins ties.
        intervals
            .iter()
            .min_by_key(|iv| i128::from(iv.high) - i128::from(iv.low))
            .map(|iv| midpoint(iv.low, iv.high))
    }
}
493}
494
495// ---------------------------------------------------------------------------
496// Drift compensation (integer linear regression)
497// ---------------------------------------------------------------------------
498
499/// Drift estimator using integer linear regression over recent measurements.
500///
501/// Tracks (time, offset) pairs and computes drift rate in parts-per-million
502/// (PPM) using integer arithmetic only.
503#[derive(Debug)]
504pub struct DriftEstimator {
505    /// Ring buffer of (elapsed_ms_since_start, offset_ms) pairs
506    samples: [(u64, i64); CLOCK_FILTER_SIZE],
507    count: usize,
508    next_idx: usize,
509}
510
511impl Default for DriftEstimator {
512    fn default() -> Self {
513        Self::new()
514    }
515}
516
517impl DriftEstimator {
518    /// Create a new drift estimator.
519    pub fn new() -> Self {
520        Self {
521            samples: [(0, 0); CLOCK_FILTER_SIZE],
522            count: 0,
523            next_idx: 0,
524        }
525    }
526
527    /// Add a data point: elapsed time in ms since first measurement, and offset
528    /// in ms.
529    pub fn add(&mut self, elapsed_ms: u64, offset_ms: i64) {
530        self.samples[self.next_idx] = (elapsed_ms, offset_ms);
531        self.next_idx = (self.next_idx + 1) % CLOCK_FILTER_SIZE;
532        if self.count < CLOCK_FILTER_SIZE {
533            self.count += 1;
534        }
535    }
536
537    /// Compute drift rate in parts-per-billion (PPB) using integer-only
538    /// linear regression (slope = sum(dx*dy) / sum(dx^2)).
539    ///
540    /// Returns None if fewer than 2 samples.
541    pub fn drift_ppb(&self) -> Option<i64> {
542        if self.count < 2 {
543            return None;
544        }
545
546        // Compute means (integer approximation)
547        let mut sum_x: i128 = 0;
548        let mut sum_y: i128 = 0;
549        for i in 0..self.count {
550            sum_x += self.samples[i].0 as i128;
551            sum_y += self.samples[i].1 as i128;
552        }
553        let n = self.count as i128;
554        let mean_x = sum_x / n;
555        let mean_y = sum_y / n;
556
557        // Linear regression: slope = sum((xi - mean_x)(yi - mean_y)) / sum((xi -
558        // mean_x)^2)
559        let mut num: i128 = 0;
560        let mut den: i128 = 0;
561        for i in 0..self.count {
562            let dx = self.samples[i].0 as i128 - mean_x;
563            let dy = self.samples[i].1 as i128 - mean_y;
564            num += dx * dy;
565            den += dx * dx;
566        }
567
568        if den == 0 {
569            return Some(0);
570        }
571
572        // slope = num/den is in ms_offset / ms_elapsed = dimensionless ratio
573        // Convert to PPB: slope * 1_000_000_000
574        // = (num * 1_000_000_000) / den
575        let ppb = (num * 1_000_000_000) / den;
576        Some(ppb as i64)
577    }
578}
579
580// ---------------------------------------------------------------------------
581// NTP Client
582// ---------------------------------------------------------------------------
583
584/// Global NTP state: last known offset in milliseconds
585static NTP_OFFSET_MS: AtomicI64 = AtomicI64::new(0);
586
587/// Global NTP state: last sync timestamp (Unix epoch seconds)
588static LAST_SYNC_EPOCH: AtomicU64 = AtomicU64::new(0);
589
590/// Global NTP state: current poll interval in seconds
591static POLL_INTERVAL: AtomicU64 = AtomicU64::new(MIN_POLL_INTERVAL as u64);
592
593/// NTP client managing time synchronization with one or more servers.
594pub struct NtpClient {
595    /// Clock filter for selecting best sample
596    pub filter: ClockFilter,
597    /// Drift estimator for frequency compensation
598    pub drift: DriftEstimator,
599    /// Current poll interval in seconds
600    pub poll_interval: u32,
601    /// Our stratum level (one more than best server's stratum)
602    pub stratum: u8,
603    /// Whether a step correction has been applied since boot
604    pub step_applied: bool,
605    /// Measurement start time in ms (for drift tracking)
606    pub start_ms: u64,
607}
608
609impl Default for NtpClient {
610    fn default() -> Self {
611        Self::new()
612    }
613}
614
615impl NtpClient {
616    /// Create a new NTP client.
617    pub fn new() -> Self {
618        Self {
619            filter: ClockFilter::new(),
620            drift: DriftEstimator::new(),
621            poll_interval: MIN_POLL_INTERVAL,
622            stratum: 16, // unsynchronized
623            step_applied: false,
624            start_ms: 0,
625        }
626    }
627
628    /// Create a client request packet for the given transmit timestamp.
629    pub fn create_request(&self, transmit_ts: NtpTimestamp) -> [u8; NTP_PACKET_SIZE] {
630        let pkt = NtpPacket::new_request(transmit_ts);
631        pkt.to_bytes()
632    }
633
634    /// Process a server response and return the computed measurement.
635    ///
636    /// - `response`: the raw 48-byte NTP response
637    /// - `t1`: our original transmit timestamp
638    /// - `t4`: our receive timestamp (when we got the response)
639    pub fn process_response(
640        &mut self,
641        response: &[u8],
642        t1: &NtpTimestamp,
643        t4: &NtpTimestamp,
644    ) -> Option<NtpMeasurement> {
645        let pkt = NtpPacket::from_bytes(response)?;
646
647        // Reject KoD packets
648        if pkt.is_kod() {
649            return None;
650        }
651
652        // Reject non-server responses
653        if pkt.mode != MODE_SERVER {
654            return None;
655        }
656
657        // Reject unsynchronized servers
658        if pkt.stratum == 0 || pkt.stratum >= 16 {
659            return None;
660        }
661
662        let t2 = &pkt.receive_ts;
663        let t3 = &pkt.transmit_ts;
664        let measurement = calculate_offset_delay(t1, t2, t3, t4);
665
666        // Reject if delay is unreasonable
667        if measurement.delay_ms.unsigned_abs() > MAX_DELAY_MS as u64 {
668            return None;
669        }
670
671        // Update our stratum
672        self.stratum = pkt.stratum.saturating_add(1).min(15);
673
674        // Add to clock filter
675        self.filter.add_sample(measurement);
676
677        // Add to drift estimator
678        let elapsed = t4.to_millis().saturating_sub(self.start_ms);
679        self.drift.add(elapsed, measurement.offset_ms);
680
681        // Handle leap indicator
682        if pkt.leap == LeapIndicator::AddSecond || pkt.leap == LeapIndicator::DeleteSecond {
683            // Log leap second warning (actual adjustment happens at midnight)
684            // In a full implementation this would schedule the leap second
685            // insertion
686        }
687
688        Some(measurement)
689    }
690
691    /// Decide whether to step (instant jump) or slew (gradual adjust) the
692    /// clock, and apply the correction via the RTC integration point.
693    ///
694    /// Returns the applied correction in milliseconds.
695    pub fn adjust_clock(&mut self) -> i64 {
696        let best = match self.filter.best_sample() {
697            Some(s) => s,
698            None => return 0,
699        };
700
701        let offset = best.offset_ms;
702
703        // Apply drift compensation
704        let drift_correction = match self.drift.drift_ppb() {
705            Some(ppb) => {
706                // drift_correction_ms = ppb * poll_interval_ms / 1_000_000_000
707                let poll_ms = (self.poll_interval as i64) * 1000;
708                (ppb * poll_ms) / 1_000_000_000
709            }
710            None => 0,
711        };
712
713        let total_correction = offset + drift_correction;
714
715        // Step or slew decision
716        if total_correction.unsigned_abs() > STEP_THRESHOLD_MS as u64 && !self.step_applied {
717            // Step: apply full correction immediately
718            apply_time_correction(total_correction);
719            self.step_applied = true;
720            // Reset poll interval after step
721            self.poll_interval = MIN_POLL_INTERVAL;
722        } else {
723            // Slew: apply correction gradually
724            apply_time_correction(total_correction);
725            // Increase poll interval on stable corrections
726            if total_correction.unsigned_abs() < 10 {
727                self.increase_poll_interval();
728            } else if total_correction.unsigned_abs() > 50 {
729                self.decrease_poll_interval();
730            }
731        }
732
733        NTP_OFFSET_MS.store(total_correction, Ordering::Relaxed);
734        total_correction
735    }
736
737    /// Increase poll interval (double, up to maximum).
738    fn increase_poll_interval(&mut self) {
739        let new_interval = self.poll_interval.saturating_mul(2).min(MAX_POLL_INTERVAL);
740        self.poll_interval = new_interval;
741        POLL_INTERVAL.store(new_interval as u64, Ordering::Relaxed);
742    }
743
744    /// Decrease poll interval (halve, down to minimum).
745    fn decrease_poll_interval(&mut self) {
746        let new_interval = (self.poll_interval / 2).max(MIN_POLL_INTERVAL);
747        self.poll_interval = new_interval;
748        POLL_INTERVAL.store(new_interval as u64, Ordering::Relaxed);
749    }
750
751    /// Get the current poll interval in seconds.
752    pub fn get_poll_interval(&self) -> u32 {
753        self.poll_interval
754    }
755}
756
757// ---------------------------------------------------------------------------
758// RTC Integration
759// ---------------------------------------------------------------------------
760
761/// Apply a time correction to the system clock via the RTC subsystem.
762///
763/// On x86_64, this calls into the RTC module's NTP correction interface.
764/// On other architectures, the offset is stored in the global atomic.
765fn apply_time_correction(offset_ms: i64) {
766    #[cfg(target_arch = "x86_64")]
767    {
768        crate::arch::x86_64::rtc::set_time_correction(offset_ms);
769    }
770    #[cfg(not(target_arch = "x86_64"))]
771    {
772        NTP_OFFSET_MS.store(offset_ms, Ordering::Relaxed);
773    }
774}
775
776/// Get the last NTP-computed offset in milliseconds.
777pub fn get_ntp_offset_ms() -> i64 {
778    NTP_OFFSET_MS.load(Ordering::Relaxed)
779}
780
781/// Get the last NTP sync time as Unix epoch seconds (0 = never synced).
782pub fn get_last_sync_epoch() -> u64 {
783    LAST_SYNC_EPOCH.load(Ordering::Relaxed)
784}
785
786/// Record that a successful NTP sync occurred at the given Unix epoch time.
787pub fn record_sync(epoch_secs: u64) {
788    LAST_SYNC_EPOCH.store(epoch_secs, Ordering::Relaxed);
789}
790
791/// Get the current poll interval in seconds.
792pub fn get_poll_interval() -> u64 {
793    POLL_INTERVAL.load(Ordering::Relaxed)
794}
795
796/// Trigger a boot-time NTP synchronization.
797///
798/// In a full implementation this would send a request to a configured
799/// NTP server via UDP port 123. Here we initialize the client state.
800pub fn boot_sync() -> NtpClient {
801    #[allow(unused_mut)]
802    let mut client = NtpClient::new();
803    // Set start time from system clock if available
804    #[cfg(target_arch = "x86_64")]
805    {
806        let epoch = crate::arch::x86_64::rtc::current_epoch_secs();
807        client.start_ms = epoch * 1000;
808    }
809    client
810}
811
812// ---------------------------------------------------------------------------
813// Tests
814// ---------------------------------------------------------------------------
815
816#[cfg(test)]
817mod tests {
818    #[allow(unused_imports)]
819    use alloc::vec;
820
821    use super::*;
822
823    // -- Timestamp tests --
824
825    #[test]
826    fn test_timestamp_to_from_bytes_roundtrip() {
827        let ts = NtpTimestamp::new(0xAABBCCDD, 0x11223344);
828        let bytes = ts.to_bytes();
829        let ts2 = NtpTimestamp::from_bytes(&bytes);
830        assert_eq!(ts, ts2);
831    }
832
833    #[test]
834    fn test_timestamp_zero() {
835        let ts = NtpTimestamp::new(0, 0);
836        let bytes = ts.to_bytes();
837        assert_eq!(bytes, [0u8; 8]);
838        let ts2 = NtpTimestamp::from_bytes(&bytes);
839        assert_eq!(ts2.seconds, 0);
840        assert_eq!(ts2.fraction, 0);
841    }
842
843    #[test]
844    fn test_timestamp_ntp_to_unix_epoch_conversion() {
845        // NTP epoch is 1900-01-01, Unix is 1970-01-01
846        // At NTP seconds = NTP_UNIX_OFFSET, Unix should be 0
847        let ts = NtpTimestamp::new(NTP_UNIX_OFFSET as u32, 0);
848        assert_eq!(ts.to_unix_secs(), 0);
849    }
850
851    #[test]
852    fn test_timestamp_unix_to_ntp_conversion() {
853        // Unix epoch 0 -> NTP seconds = NTP_UNIX_OFFSET
854        let ts = NtpTimestamp::from_unix(0, 0);
855        assert_eq!(ts.seconds, NTP_UNIX_OFFSET as u32);
856        assert_eq!(ts.fraction, 0);
857    }
858
859    #[test]
860    fn test_timestamp_millis_fraction() {
861        // 500ms = 0.5s -> fraction ~= 2^31 = 2147483648
862        let ts = NtpTimestamp::from_unix(0, 500);
863        // Allow small rounding: should be close to 2^31
864        let expected = 2_147_483_648u32;
865        let diff = if ts.fraction > expected {
866            ts.fraction - expected
867        } else {
868            expected - ts.fraction
869        };
870        assert!(
871            diff < 5000,
872            "fraction {} not close to {}",
873            ts.fraction,
874            expected
875        );
876    }
877
878    // -- Packet tests --
879
880    #[test]
881    fn test_packet_serialization_roundtrip() {
882        let ts = NtpTimestamp::new(1000, 2000);
883        let pkt = NtpPacket::new_request(ts);
884        let bytes = pkt.to_bytes();
885        let pkt2 = NtpPacket::from_bytes(&bytes).unwrap();
886        assert_eq!(pkt, pkt2);
887    }
888
889    #[test]
890    fn test_packet_too_short() {
891        let buf = [0u8; 10];
892        assert!(NtpPacket::from_bytes(&buf).is_none());
893    }
894
895    #[test]
896    fn test_packet_header_fields() {
897        let ts = NtpTimestamp::new(100, 200);
898        let pkt = NtpPacket::new_request(ts);
899        let bytes = pkt.to_bytes();
900        // Byte 0: LI=0(2), VN=4(3), Mode=3(3) -> 0b00_100_011 = 0x23
901        assert_eq!(bytes[0], 0x23);
902        assert_eq!(bytes[1], 0); // stratum
903        assert_eq!(bytes[2], 6); // poll
904    }
905
906    #[test]
907    fn test_kod_detection() {
908        let mut pkt = NtpPacket::new_request(NtpTimestamp::default());
909        pkt.mode = MODE_SERVER;
910        pkt.stratum = 0;
911        pkt.reference_id = *b"DENY";
912        assert!(pkt.is_kod());
913        assert_eq!(pkt.kod_code(), Some(*b"DENY"));
914    }
915
916    #[test]
917    fn test_kod_not_triggered_normal_packet() {
918        let mut pkt = NtpPacket::new_request(NtpTimestamp::default());
919        pkt.mode = MODE_SERVER;
920        pkt.stratum = 2;
921        pkt.reference_id = [192, 168, 1, 1]; // IP address, not ASCII
922        assert!(!pkt.is_kod());
923        assert_eq!(pkt.kod_code(), None);
924    }
925
926    // -- Clock offset / delay tests --
927
928    #[test]
929    fn test_clock_offset_symmetric() {
930        // Symmetric delay: server 100ms ahead
931        // t1=1000, t2=1100, t3=1100, t4=1000 (all in NTP seconds-as-ms)
932        let t1 = NtpTimestamp::new(1, 0); // 1000ms
933        let t2 = NtpTimestamp::new(2, 0); // 2000ms (server +1s)
934        let t3 = NtpTimestamp::new(2, 0); // 2000ms
935        let t4 = NtpTimestamp::new(1, 0); // 1000ms
936        let m = calculate_offset_delay(&t1, &t2, &t3, &t4);
937        // offset = ((2000-1000) + (2000-1000)) / 2 = 1000ms
938        assert_eq!(m.offset_ms, 1000);
939        // delay = (1000-1000) - (2000-2000) = 0
940        assert_eq!(m.delay_ms, 0);
941    }
942
943    #[test]
944    fn test_round_trip_delay() {
945        // t1=0s, t2=0.05s, t3=0.05s, t4=0.1s -> RTT=0.1s, offset=0
946        let t1 = NtpTimestamp::new(0, 0);
947        let t2 = NtpTimestamp::new(0, 214748365); // ~50ms
948        let t3 = NtpTimestamp::new(0, 214748365); // ~50ms
949        let t4 = NtpTimestamp::new(0, 429496730); // ~100ms
950        let m = calculate_offset_delay(&t1, &t2, &t3, &t4);
951        // offset should be ~0, delay should be ~100ms
952        assert!(m.offset_ms.unsigned_abs() <= 1, "offset: {}", m.offset_ms);
953        assert!(
954            m.delay_ms >= 95 && m.delay_ms <= 105,
955            "delay: {}",
956            m.delay_ms
957        );
958    }
959
960    // -- Clock filter tests --
961
962    #[test]
963    fn test_clock_filter_best_sample() {
964        let mut filter = ClockFilter::new();
965        filter.add_sample(NtpMeasurement {
966            offset_ms: 10,
967            delay_ms: 100,
968        });
969        filter.add_sample(NtpMeasurement {
970            offset_ms: 12,
971            delay_ms: 50,
972        });
973        filter.add_sample(NtpMeasurement {
974            offset_ms: 8,
975            delay_ms: 200,
976        });
977        let best = filter.best_sample().unwrap();
978        // Best = lowest delay = 50ms
979        assert_eq!(best.delay_ms, 50);
980        assert_eq!(best.offset_ms, 12);
981    }
982
983    #[test]
984    fn test_clock_filter_empty() {
985        let filter = ClockFilter::new();
986        assert!(filter.best_sample().is_none());
987        assert_eq!(filter.sample_count(), 0);
988    }
989
990    #[test]
991    fn test_clock_filter_jitter() {
992        let mut filter = ClockFilter::new();
993        // Add samples with increasing offsets: 10, 20, 30, 40
994        for i in 1..=4 {
995            filter.add_sample(NtpMeasurement {
996                offset_ms: i * 10,
997                delay_ms: 50,
998            });
999        }
1000        let jitter = filter.jitter_ms();
1001        // Differences: 10, 10, 10 -> RMS = sqrt(100) = 10
1002        assert_eq!(jitter, 10);
1003    }
1004
1005    // -- Marzullo's algorithm tests --
1006
1007    #[cfg(feature = "alloc")]
1008    #[test]
1009    fn test_marzullo_single_source() {
1010        let intervals = vec![TimeInterval { low: 90, high: 110 }];
1011        let result = marzullo_select(&intervals);
1012        assert_eq!(result, Some(100));
1013    }
1014
1015    #[cfg(feature = "alloc")]
1016    #[test]
1017    fn test_marzullo_overlapping_sources() {
1018        // Three sources agreeing roughly on offset=100
1019        let intervals = vec![
1020            TimeInterval { low: 90, high: 110 },
1021            TimeInterval { low: 95, high: 115 },
1022            TimeInterval { low: 85, high: 105 },
1023        ];
1024        let result = marzullo_select(&intervals).unwrap();
1025        // Majority intersection should be around 95..105, midpoint ~100
1026        assert!(result >= 90 && result <= 110, "result: {}", result);
1027    }
1028
1029    #[cfg(feature = "alloc")]
1030    #[test]
1031    fn test_marzullo_empty() {
1032        let intervals: Vec<TimeInterval> = vec![];
1033        assert!(marzullo_select(&intervals).is_none());
1034    }
1035
1036    // -- Drift estimator tests --
1037
1038    #[test]
1039    fn test_drift_estimator_too_few_samples() {
1040        let est = DriftEstimator::new();
1041        assert!(est.drift_ppb().is_none());
1042    }
1043
1044    #[test]
1045    fn test_drift_estimator_constant_offset() {
1046        let mut est = DriftEstimator::new();
1047        // Constant offset = no drift
1048        est.add(0, 100);
1049        est.add(1000, 100);
1050        est.add(2000, 100);
1051        let ppb = est.drift_ppb().unwrap();
1052        assert_eq!(ppb, 0);
1053    }
1054
1055    #[test]
1056    fn test_drift_estimator_linear_drift() {
1057        let mut est = DriftEstimator::new();
1058        // 1ms offset per 1000ms elapsed = 1 PPM = 1000 PPB
1059        est.add(0, 0);
1060        est.add(1000, 1);
1061        est.add(2000, 2);
1062        est.add(3000, 3);
1063        let ppb = est.drift_ppb().unwrap();
1064        assert_eq!(ppb, 1_000_000);
1065    }
1066
1067    // -- Poll interval tests --
1068
1069    #[test]
1070    fn test_poll_interval_bounds() {
1071        let mut client = NtpClient::new();
1072        assert_eq!(client.get_poll_interval(), MIN_POLL_INTERVAL);
1073
1074        // Increase multiple times
1075        for _ in 0..20 {
1076            client.increase_poll_interval();
1077        }
1078        assert!(client.get_poll_interval() <= MAX_POLL_INTERVAL);
1079        assert_eq!(client.get_poll_interval(), MAX_POLL_INTERVAL);
1080
1081        // Decrease multiple times
1082        for _ in 0..20 {
1083            client.decrease_poll_interval();
1084        }
1085        assert!(client.get_poll_interval() >= MIN_POLL_INTERVAL);
1086        assert_eq!(client.get_poll_interval(), MIN_POLL_INTERVAL);
1087    }
1088
1089    // -- Leap indicator tests --
1090
1091    #[test]
1092    fn test_leap_indicator_roundtrip() {
1093        for val in 0..=3u8 {
1094            let li = LeapIndicator::from_u8(val);
1095            assert_eq!(li as u8, val);
1096        }
1097    }
1098
1099    #[test]
1100    fn test_leap_indicator_in_packet() {
1101        let ts = NtpTimestamp::default();
1102        let mut pkt = NtpPacket::new_request(ts);
1103        pkt.leap = LeapIndicator::AddSecond;
1104        let bytes = pkt.to_bytes();
1105        let pkt2 = NtpPacket::from_bytes(&bytes).unwrap();
1106        assert_eq!(pkt2.leap, LeapIndicator::AddSecond);
1107    }
1108
1109    // -- Integer sqrt tests --
1110
1111    #[test]
1112    fn test_isqrt() {
1113        assert_eq!(isqrt(0), 0);
1114        assert_eq!(isqrt(1), 1);
1115        assert_eq!(isqrt(4), 2);
1116        assert_eq!(isqrt(9), 3);
1117        assert_eq!(isqrt(100), 10);
1118        assert_eq!(isqrt(99), 9); // floor
1119        assert_eq!(isqrt(1_000_000), 1000);
1120    }
1121}