⚠️ VeridianOS Kernel Documentation - This is low-level kernel code. All functions are unsafe unless explicitly marked otherwise. The kernel is built as `no_std`.

veridian_kernel/net/
wireguard.rs

1//! WireGuard VPN tunnel implementation
2//!
3//! Implements the WireGuard protocol (Noise_IKpsk2 handshake pattern) for
4//! secure VPN tunneling. Provides:
5//! - BLAKE2s hash function (RFC 7693)
6//! - Noise IK handshake with pre-shared key
7//! - ChaCha20-Poly1305 AEAD transport encryption
8//! - Anti-replay sliding window
9//! - Peer management with key rotation
10//! - Virtual network interface (wg0)
11//! - Timer-based session management
12
13#![allow(dead_code)]
14
15use alloc::{collections::BTreeMap, vec::Vec};
16use core::sync::atomic::{AtomicU64, Ordering};
17
18use super::{IpAddress, Ipv4Address};
19use crate::crypto::{
20    cipher_suite::{CipherSuite, HmacAlgorithm, KdfAlgorithm},
21    hash::{blake2s_hash, blake2s_keyed_hash},
22};
23
24// ── Constants ────────────────────────────────────────────────────────────────
25
/// Default WireGuard UDP port
pub const DEFAULT_PORT: u16 = 51_820;

/// WireGuard message types (value of the first byte of every message)
const MSG_HANDSHAKE_INIT: u8 = 1;
const MSG_HANDSHAKE_RESP: u8 = 2;
const MSG_COOKIE_REPLY: u8 = 3;
const MSG_TRANSPORT_DATA: u8 = 4;

/// Key size in bytes (256-bit)
const KEY_SIZE: usize = 256 / 8;

/// Nonce size for ChaCha20-Poly1305 (96-bit)
const CHACHA_NONCE_SIZE: usize = 96 / 8;

/// Poly1305 authentication tag size (128-bit)
const TAG_SIZE: usize = 128 / 8;

/// Handshake initiation message size in bytes (148):
/// type+reserved (4) | sender index (4) | ephemeral (32) |
/// sealed static key (32 + tag) | sealed 12-byte timestamp (12 + tag) |
/// mac1 (16) | mac2 (16)
const HANDSHAKE_INIT_SIZE: usize =
    4 + 4 + KEY_SIZE + (KEY_SIZE + TAG_SIZE) + (12 + TAG_SIZE) + 16 + 16;

/// Handshake response message size in bytes (92):
/// type+reserved (4) | sender index (4) | receiver index (4) | ephemeral (32) |
/// sealed empty payload (tag only) | mac1 (16) | mac2 (16)
const HANDSHAKE_RESP_SIZE: usize = 4 + 4 + 4 + KEY_SIZE + TAG_SIZE + 16 + 16;

/// Anti-replay window size in bits
const REPLAY_WINDOW_BITS: usize = 2048;

/// Anti-replay window size in u64 words
const REPLAY_WINDOW_WORDS: usize = REPLAY_WINDOW_BITS / (8 * core::mem::size_of::<u64>());

/// Rekey after this many messages (2^60); `SessionKeys::next_nonce` also
/// refuses to hand out nonces at or beyond this value
const REKEY_AFTER_MESSAGES: u64 = 1 << 60;

/// Rekey after this many seconds
const REKEY_AFTER_SECONDS: u64 = 2 * 60;

/// Session expires after this many seconds without traffic in either direction
const SESSION_EXPIRY_SECONDS: u64 = 3 * 60;

/// Default persistent keepalive interval (seconds)
const DEFAULT_KEEPALIVE_INTERVAL: u64 = 25;

/// Maximum handshake retry attempts
const MAX_HANDSHAKE_RETRIES: u32 = 5;

/// Initial handshake retry delay (milliseconds); doubled on every retry
const INITIAL_RETRY_DELAY_MS: u64 = 1_000;

/// Noise protocol name; hashed to form the initial chaining key
const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";

/// WireGuard identifier string; mixed into the initial handshake hash
const IDENTIFIER: &[u8] = b"WireGuard v1 zx2c4 Jason@zx2c4.com";

/// Transport message header: type+reserved (4) | receiver index (4) | counter (8)
const TRANSPORT_HEADER_SIZE: usize = 4 + 4 + 8;
82
83// ── BLAKE2s (delegates to crate::crypto::hash) ─────────────────────────────
84
85/// Compute BLAKE2s hash of data with given output length
86///
87/// Delegates to `crate::crypto::hash::blake2s_hash`.
88pub fn blake2s(data: &[u8], outlen: usize) -> [u8; 32] {
89    blake2s_hash(data, outlen)
90}
91
92/// Compute keyed BLAKE2s hash
93///
94/// Delegates to `crate::crypto::hash::blake2s_keyed_hash`.
95pub fn blake2s_keyed(key: &[u8], data: &[u8], outlen: usize) -> [u8; 32] {
96    blake2s_keyed_hash(key, data, outlen)
97}
98
99/// HMAC-BLAKE2s: delegates to `HmacAlgorithm::HmacBlake2s`
100pub fn hmac_blake2s(key: &[u8], data: &[u8]) -> [u8; 32] {
101    HmacAlgorithm::HmacBlake2s.compute(key, data)
102}
103
104/// HKDF-BLAKE2s key derivation (extract + expand, two outputs)
105///
106/// Delegates to `KdfAlgorithm::HkdfBlake2s`.
107fn hkdf(chaining_key: &[u8; 32], input: &[u8]) -> ([u8; 32], [u8; 32]) {
108    KdfAlgorithm::HkdfBlake2s.extract_expand2(chaining_key, input)
109}
110
111/// HKDF with three outputs
112///
113/// Delegates to `KdfAlgorithm::HkdfBlake2s`.
114fn hkdf3(chaining_key: &[u8; 32], input: &[u8]) -> ([u8; 32], [u8; 32], [u8; 32]) {
115    KdfAlgorithm::HkdfBlake2s.extract_expand3(chaining_key, input)
116}
117
118// ── X25519 Key Exchange (delegates to crate::crypto::asymmetric) ────────────
119
120/// X25519 key pair (Curve25519 Diffie-Hellman)
121///
122/// Uses the real X25519 scalar multiplication from `crate::crypto::asymmetric`
123/// for public key derivation and Diffie-Hellman key exchange.
124#[derive(Clone)]
125pub struct X25519KeyPair {
126    pub private_key: [u8; 32],
127    pub public_key: [u8; 32],
128}
129
130impl X25519KeyPair {
131    /// Generate a new key pair from a seed (deterministic for testing)
132    pub fn from_seed(seed: &[u8; 32]) -> Self {
133        let mut private_key = *seed;
134        // Clamp per RFC 7748
135        private_key[0] &= 248;
136        private_key[31] &= 127;
137        private_key[31] |= 64;
138
139        // Derive public key via real X25519 scalar multiplication with basepoint
140        let public_key = crate::crypto::asymmetric::x25519_scalar_mult(
141            &private_key,
142            &crate::crypto::asymmetric::X25519_BASEPOINT,
143        );
144        Self {
145            private_key,
146            public_key,
147        }
148    }
149
150    /// Perform Diffie-Hellman key exchange
151    pub fn dh(&self, their_public: &[u8; 32]) -> [u8; 32] {
152        crate::crypto::asymmetric::x25519_scalar_mult(&self.private_key, their_public)
153    }
154}
155
156// ── ChaCha20-Poly1305 AEAD (delegates to CipherSuite) ──────────────────────
157
158/// AEAD encrypt with ChaCha20-Poly1305
159///
160/// Uses `CipherSuite::ChaCha20Poly1305` from the shared crypto module.
161/// Returns ciphertext || 16-byte tag.
162fn aead_encrypt(key: &[u8; 32], nonce: u64, aad: &[u8], plaintext: &[u8]) -> Vec<u8> {
163    let mut nonce_bytes = [0u8; CHACHA_NONCE_SIZE];
164    nonce_bytes[4..12].copy_from_slice(&nonce.to_le_bytes());
165
166    CipherSuite::ChaCha20Poly1305
167        .encrypt_aead(key, &nonce_bytes, aad, plaintext)
168        .unwrap_or_default()
169}
170
171/// AEAD decrypt with ChaCha20-Poly1305
172///
173/// Uses `CipherSuite::ChaCha20Poly1305` from the shared crypto module.
174/// Returns plaintext or error if authentication fails.
175fn aead_decrypt(
176    key: &[u8; 32],
177    nonce: u64,
178    aad: &[u8],
179    ciphertext_and_tag: &[u8],
180) -> Result<Vec<u8>, WireGuardError> {
181    if ciphertext_and_tag.len() < TAG_SIZE {
182        return Err(WireGuardError::DecryptionFailed);
183    }
184
185    let mut nonce_bytes = [0u8; CHACHA_NONCE_SIZE];
186    nonce_bytes[4..12].copy_from_slice(&nonce.to_le_bytes());
187
188    CipherSuite::ChaCha20Poly1305
189        .decrypt_aead(key, &nonce_bytes, aad, ciphertext_and_tag)
190        .map_err(|_| WireGuardError::DecryptionFailed)
191}
192
193// ── Error Types ─────────────────────────────────────────────────────────────
194
/// Errors produced by the WireGuard implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WireGuardError {
    /// The handshake failed or timed out
    HandshakeFailed,
    /// AEAD decryption/authentication did not succeed
    DecryptionFailed,
    /// The anti-replay window rejected the message counter
    ReplayDetected,
    /// The session is no longer valid
    SessionExpired,
    /// No peer matches the request
    PeerNotFound,
    /// The message is malformed
    InvalidMessage,
    /// The sending nonce counter ran out
    NonceOverflow,
    /// Keys must be rotated before continuing
    RekeyRequired,
    /// The interface lacks required configuration
    NotConfigured,
    /// The peer limit has been reached
    MaxPeersReached,
}
219
220// ── Anti-Replay Window ──────────────────────────────────────────────────────
221
222/// Sliding-window anti-replay mechanism (2048-bit bitmap)
223#[derive(Clone)]
224pub struct AntiReplayWindow {
225    /// Highest accepted counter value
226    last_counter: u64,
227    /// Bitmap of recently seen counters (relative to last_counter)
228    bitmap: [u64; REPLAY_WINDOW_WORDS],
229}
230
231impl Default for AntiReplayWindow {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
237impl AntiReplayWindow {
238    /// Create new anti-replay window
239    pub fn new() -> Self {
240        Self {
241            last_counter: 0,
242            bitmap: [0u64; REPLAY_WINDOW_WORDS],
243        }
244    }
245
246    /// Check if a counter value is acceptable (not a replay)
247    pub fn check(&self, counter: u64) -> bool {
248        if counter == 0 && self.last_counter == 0 && self.bitmap[0] == 0 {
249            // First packet ever
250            return true;
251        }
252        if counter > self.last_counter {
253            return true;
254        }
255        let diff = self.last_counter - counter;
256        if diff >= REPLAY_WINDOW_BITS as u64 {
257            return false; // Too old
258        }
259        let word_idx = (diff / 64) as usize;
260        let bit_idx = (diff % 64) as u32;
261        if word_idx >= REPLAY_WINDOW_WORDS {
262            return false;
263        }
264        (self.bitmap[word_idx] & (1u64 << bit_idx)) == 0
265    }
266
267    /// Update the window after accepting a packet
268    pub fn update(&mut self, counter: u64) {
269        if counter > self.last_counter {
270            let shift = counter - self.last_counter;
271            if shift >= REPLAY_WINDOW_BITS as u64 {
272                // Reset entire window
273                self.bitmap = [0u64; REPLAY_WINDOW_WORDS];
274            } else {
275                self.shift_window(shift as usize);
276            }
277            self.last_counter = counter;
278            // Mark current counter as seen (bit 0)
279            self.bitmap[0] |= 1;
280        } else {
281            let diff = self.last_counter - counter;
282            let word_idx = (diff / 64) as usize;
283            let bit_idx = (diff % 64) as u32;
284            if word_idx < REPLAY_WINDOW_WORDS {
285                self.bitmap[word_idx] |= 1u64 << bit_idx;
286            }
287        }
288    }
289
290    /// Shift the bitmap window by the given number of positions
291    fn shift_window(&mut self, shift: usize) {
292        let word_shift = shift / 64;
293        let bit_shift = (shift % 64) as u32;
294
295        if word_shift >= REPLAY_WINDOW_WORDS {
296            self.bitmap = [0u64; REPLAY_WINDOW_WORDS];
297            return;
298        }
299
300        // Shift by whole words
301        if word_shift > 0 {
302            let mut i = REPLAY_WINDOW_WORDS;
303            while i > word_shift {
304                i -= 1;
305                self.bitmap[i] = self.bitmap[i - word_shift];
306            }
307            let mut j = 0;
308            while j < word_shift {
309                self.bitmap[j] = 0;
310                j += 1;
311            }
312        }
313
314        // Shift by remaining bits
315        if bit_shift > 0 {
316            let mut i = REPLAY_WINDOW_WORDS;
317            while i > 1 {
318                i -= 1;
319                self.bitmap[i] =
320                    (self.bitmap[i] << bit_shift) | (self.bitmap[i - 1] >> (64 - bit_shift));
321            }
322            self.bitmap[0] <<= bit_shift;
323        }
324    }
325}
326
327// ── Handshake State ─────────────────────────────────────────────────────────
328
/// Progress of a peer's Noise handshake.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeState {
    /// Idle: no handshake underway
    None,
    /// We sent an initiation and await the response
    InitSent,
    /// We received an initiation (responder role)
    InitReceived,
    /// Handshake finished; a session is available
    Established,
}
341
342/// Session keys derived from handshake
343pub struct SessionKeys {
344    /// Key for sending
345    pub sending_key: [u8; 32],
346    /// Key for receiving
347    pub receiving_key: [u8; 32],
348    /// Sending nonce counter
349    pub sending_nonce: AtomicU64,
350    /// Time when keys were derived (uptime seconds)
351    pub created_at: u64,
352    /// Number of messages sent with these keys
353    pub messages_sent: AtomicU64,
354}
355
356impl SessionKeys {
357    /// Create new session keys
358    pub fn new(sending_key: [u8; 32], receiving_key: [u8; 32], now: u64) -> Self {
359        Self {
360            sending_key,
361            receiving_key,
362            sending_nonce: AtomicU64::new(0),
363            created_at: now,
364            messages_sent: AtomicU64::new(0),
365        }
366    }
367
368    /// Check if keys need rotation
369    pub fn needs_rekey(&self, now: u64) -> bool {
370        let messages = self.messages_sent.load(Ordering::Relaxed);
371        let age = now.saturating_sub(self.created_at);
372        messages >= REKEY_AFTER_MESSAGES || age >= REKEY_AFTER_SECONDS
373    }
374
375    /// Get and increment the sending nonce
376    pub fn next_nonce(&self) -> Result<u64, WireGuardError> {
377        let nonce = self.sending_nonce.fetch_add(1, Ordering::Relaxed);
378        if nonce >= REKEY_AFTER_MESSAGES {
379            return Err(WireGuardError::NonceOverflow);
380        }
381        self.messages_sent.fetch_add(1, Ordering::Relaxed);
382        Ok(nonce)
383    }
384}
385
386/// Handshake context for Noise_IKpsk2
387#[derive(Clone)]
388pub struct HandshakeContext {
389    /// Chaining key
390    pub chaining_key: [u8; 32],
391    /// Hash state
392    pub hash: [u8; 32],
393    /// Local ephemeral key pair
394    pub ephemeral: Option<X25519KeyPair>,
395    /// Remote ephemeral public key
396    pub remote_ephemeral: Option<[u8; 32]>,
397    /// Our sender index
398    pub sender_index: u32,
399    /// Their sender index
400    pub receiver_index: u32,
401}
402
403impl Default for HandshakeContext {
404    fn default() -> Self {
405        Self::new()
406    }
407}
408
409impl HandshakeContext {
410    /// Initialize a new handshake context
411    pub fn new() -> Self {
412        // Initial chaining key = BLAKE2s(CONSTRUCTION)
413        let chaining_key = blake2s(CONSTRUCTION, 32);
414        // Initial hash = BLAKE2s(chaining_key || IDENTIFIER)
415        let mut hash_input = Vec::with_capacity(32 + IDENTIFIER.len());
416        hash_input.extend_from_slice(&chaining_key);
417        hash_input.extend_from_slice(IDENTIFIER);
418        let hash = blake2s(&hash_input, 32);
419        Self {
420            chaining_key,
421            hash,
422            ephemeral: None,
423            remote_ephemeral: None,
424            sender_index: 0,
425            receiver_index: 0,
426        }
427    }
428
429    /// Mix hash: h = BLAKE2s(h || data)
430    pub fn mix_hash(&mut self, data: &[u8]) {
431        let mut input = Vec::with_capacity(32 + data.len());
432        input.extend_from_slice(&self.hash);
433        input.extend_from_slice(data);
434        self.hash = blake2s(&input, 32);
435    }
436
437    /// Build handshake initiation message (148 bytes)
438    pub fn create_initiation(
439        &mut self,
440        static_key: &X25519KeyPair,
441        remote_static_pub: &[u8; 32],
442        preshared_key: &[u8; 32],
443        timestamp: &[u8; 12],
444        sender_index: u32,
445    ) -> [u8; HANDSHAKE_INIT_SIZE] {
446        let mut msg = [0u8; HANDSHAKE_INIT_SIZE];
447        self.sender_index = sender_index;
448
449        // Mix responder's static public key into hash
450        self.mix_hash(remote_static_pub);
451
452        // Generate ephemeral key (deterministic from static key + timestamp for test)
453        let seed = hmac_blake2s(&static_key.private_key, timestamp);
454        let ephemeral = X25519KeyPair::from_seed(&seed);
455
456        // msg[0..4] = type + reserved
457        msg[0] = MSG_HANDSHAKE_INIT;
458        // msg[4..8] = sender index
459        msg[4..8].copy_from_slice(&sender_index.to_le_bytes());
460        // msg[8..40] = unencrypted ephemeral
461        msg[8..40].copy_from_slice(&ephemeral.public_key);
462        self.mix_hash(&ephemeral.public_key);
463
464        // DH: ephemeral <-> remote static
465        let dh_result = ephemeral.dh(remote_static_pub);
466        let (ck, key) = hkdf(&self.chaining_key, &dh_result);
467        self.chaining_key = ck;
468
469        // msg[40..88] = AEAD(key, 0, static_pub, h)
470        let encrypted_static = aead_encrypt(&key, 0, &self.hash, &static_key.public_key);
471        let copy_len = core::cmp::min(encrypted_static.len(), 48);
472        msg[40..40 + copy_len].copy_from_slice(&encrypted_static[..copy_len]);
473        self.mix_hash(&msg[40..40 + copy_len]);
474
475        // DH: static <-> remote static
476        let dh_result2 = static_key.dh(remote_static_pub);
477        let (ck2, key2) = hkdf(&self.chaining_key, &dh_result2);
478        self.chaining_key = ck2;
479
480        // msg[88..116] = AEAD(key2, 0, timestamp, h)
481        let encrypted_ts = aead_encrypt(&key2, 0, &self.hash, timestamp);
482        let ts_len = core::cmp::min(encrypted_ts.len(), 28);
483        msg[88..88 + ts_len].copy_from_slice(&encrypted_ts[..ts_len]);
484        self.mix_hash(&msg[88..88 + ts_len]);
485
486        // PSK mixing
487        let (ck3, psk_key) = hkdf(&self.chaining_key, preshared_key);
488        self.chaining_key = ck3;
489        let _ = psk_key; // Used for MAC in full implementation
490
491        // msg[116..132] = MAC1 (BLAKE2s of msg[0..116] keyed with remote static hash)
492        let mac_key = blake2s(remote_static_pub, 32);
493        let mac1 = blake2s_keyed(&mac_key, &msg[..116], 16);
494        msg[116..132].copy_from_slice(&mac1[..16]);
495
496        // msg[132..148] = MAC2 (cookie, zero if no cookie)
497        // Left as zero (no cookie)
498
499        self.ephemeral = Some(ephemeral);
500        msg
501    }
502
503    /// Build handshake response message (92 bytes)
504    pub fn create_response(
505        &mut self,
506        static_key: &X25519KeyPair,
507        remote_static_pub: &[u8; 32],
508        preshared_key: &[u8; 32],
509        sender_index: u32,
510        receiver_index: u32,
511    ) -> [u8; HANDSHAKE_RESP_SIZE] {
512        let mut msg = [0u8; HANDSHAKE_RESP_SIZE];
513
514        // Generate ephemeral
515        let seed = hmac_blake2s(&static_key.private_key, &sender_index.to_le_bytes());
516        let ephemeral = X25519KeyPair::from_seed(&seed);
517
518        // msg[0..4] = type + reserved
519        msg[0] = MSG_HANDSHAKE_RESP;
520        // msg[4..8] = sender index
521        msg[4..8].copy_from_slice(&sender_index.to_le_bytes());
522        // msg[8..12] = receiver index
523        msg[8..12].copy_from_slice(&receiver_index.to_le_bytes());
524        // msg[12..44] = unencrypted ephemeral
525        msg[12..44].copy_from_slice(&ephemeral.public_key);
526        self.mix_hash(&ephemeral.public_key);
527
528        // DH: responder ephemeral <-> initiator ephemeral
529        if let Some(ref remote_eph) = self.remote_ephemeral {
530            let dh1 = ephemeral.dh(remote_eph);
531            let (ck, _) = hkdf(&self.chaining_key, &dh1);
532            self.chaining_key = ck;
533        }
534
535        // DH: responder ephemeral <-> initiator static
536        let dh2 = ephemeral.dh(remote_static_pub);
537        let (ck2, _) = hkdf(&self.chaining_key, &dh2);
538        self.chaining_key = ck2;
539
540        // PSK mixing
541        let (ck3, tau, key) = hkdf3(&self.chaining_key, preshared_key);
542        self.chaining_key = ck3;
543        self.mix_hash(&tau);
544
545        // msg[44..60] = AEAD(key, 0, empty, h) -- encrypted nothing
546        let encrypted_empty = aead_encrypt(&key, 0, &self.hash, &[]);
547        let empty_len = core::cmp::min(encrypted_empty.len(), 16);
548        msg[44..44 + empty_len].copy_from_slice(&encrypted_empty[..empty_len]);
549        self.mix_hash(&msg[44..44 + empty_len]);
550
551        // msg[60..76] = MAC1
552        let mac_key = blake2s(remote_static_pub, 32);
553        let mac1 = blake2s_keyed(&mac_key, &msg[..60], 16);
554        msg[60..76].copy_from_slice(&mac1[..16]);
555
556        // msg[76..92] = MAC2 (zero if no cookie)
557
558        self.sender_index = sender_index;
559        self.receiver_index = receiver_index;
560        self.ephemeral = Some(ephemeral);
561        msg
562    }
563
564    /// Derive transport keys from completed handshake
565    pub fn derive_transport_keys(&self) -> (SessionKeys, SessionKeys) {
566        let (t1, t2) = hkdf(&self.chaining_key, &[]);
567        let now = 0u64; // Caller should provide real timestamp
568        (SessionKeys::new(t1, t2, now), SessionKeys::new(t2, t1, now))
569    }
570}
571
572// ── Peer Management ─────────────────────────────────────────────────────────
573
574/// Allowed IP range for a peer
575#[derive(Debug, Clone, Copy, PartialEq, Eq)]
576pub struct AllowedIp {
577    /// Network address
578    pub address: Ipv4Address,
579    /// Prefix length (CIDR notation)
580    pub prefix_len: u8,
581}
582
583impl AllowedIp {
584    /// Create new allowed IP range
585    pub fn new(address: Ipv4Address, prefix_len: u8) -> Self {
586        Self {
587            address,
588            prefix_len,
589        }
590    }
591
592    /// Check if an IP address matches this allowed range
593    pub fn matches(&self, ip: &Ipv4Address) -> bool {
594        if self.prefix_len == 0 {
595            return true; // 0.0.0.0/0 matches everything
596        }
597        if self.prefix_len >= 32 {
598            return self.address == *ip;
599        }
600        let mask = u32::MAX << (32 - self.prefix_len);
601        (self.address.to_u32() & mask) == (ip.to_u32() & mask)
602    }
603}
604
605/// WireGuard peer
606pub struct WireGuardPeer {
607    /// Peer's static public key
608    pub public_key: [u8; 32],
609    /// Pre-shared key (optional, all zeros if none)
610    pub preshared_key: [u8; 32],
611    /// Peer's endpoint (IP:port)
612    pub endpoint: Option<super::SocketAddr>,
613    /// Allowed IP ranges
614    pub allowed_ips: Vec<AllowedIp>,
615    /// Handshake state
616    pub handshake_state: HandshakeState,
617    /// Current handshake context
618    pub handshake: HandshakeContext,
619    /// Current session keys
620    pub session: Option<SessionKeys>,
621    /// Anti-replay window
622    pub replay_window: AntiReplayWindow,
623    /// Last handshake timestamp (uptime seconds)
624    pub last_handshake: u64,
625    /// Last data received timestamp
626    pub last_received: u64,
627    /// Last data sent timestamp
628    pub last_sent: u64,
629    /// Persistent keepalive interval (0 = disabled)
630    pub keepalive_interval: u64,
631    /// Handshake retry count
632    pub handshake_retries: u32,
633    /// Next retry time (uptime milliseconds)
634    pub next_retry_ms: u64,
635    /// Bytes transmitted
636    pub tx_bytes: u64,
637    /// Bytes received
638    pub rx_bytes: u64,
639}
640
641impl WireGuardPeer {
642    /// Create a new peer
643    pub fn new(public_key: [u8; 32]) -> Self {
644        Self {
645            public_key,
646            preshared_key: [0u8; 32],
647            endpoint: None,
648            allowed_ips: Vec::new(),
649            handshake_state: HandshakeState::None,
650            handshake: HandshakeContext::new(),
651            session: None,
652            replay_window: AntiReplayWindow::new(),
653            last_handshake: 0,
654            last_received: 0,
655            last_sent: 0,
656            keepalive_interval: DEFAULT_KEEPALIVE_INTERVAL,
657            handshake_retries: 0,
658            next_retry_ms: 0,
659            tx_bytes: 0,
660            rx_bytes: 0,
661        }
662    }
663
664    /// Set pre-shared key
665    pub fn set_preshared_key(&mut self, psk: [u8; 32]) {
666        self.preshared_key = psk;
667    }
668
669    /// Add an allowed IP range
670    pub fn add_allowed_ip(&mut self, ip: AllowedIp) {
671        self.allowed_ips.push(ip);
672    }
673
674    /// Check if a destination IP is allowed for this peer
675    pub fn is_allowed(&self, ip: &Ipv4Address) -> bool {
676        self.allowed_ips.iter().any(|aip| aip.matches(ip))
677    }
678
679    /// Check if session has expired
680    pub fn is_session_expired(&self, now: u64) -> bool {
681        if self.last_received == 0 && self.last_sent == 0 {
682            return false; // No session yet
683        }
684        let last_activity = core::cmp::max(self.last_received, self.last_sent);
685        now.saturating_sub(last_activity) >= SESSION_EXPIRY_SECONDS
686    }
687
688    /// Check if keepalive should be sent
689    pub fn needs_keepalive(&self, now: u64) -> bool {
690        if self.keepalive_interval == 0 {
691            return false;
692        }
693        if self.handshake_state != HandshakeState::Established {
694            return false;
695        }
696        now.saturating_sub(self.last_sent) >= self.keepalive_interval
697    }
698
699    /// Calculate next handshake retry delay with exponential backoff (ms)
700    pub fn retry_delay_ms(&self) -> u64 {
701        if self.handshake_retries >= MAX_HANDSHAKE_RETRIES {
702            return 0; // Give up
703        }
704        // Exponential backoff: 1s, 2s, 4s, 8s, 16s
705        let mut delay = INITIAL_RETRY_DELAY_MS;
706        let mut i = 0u32;
707        while i < self.handshake_retries {
708            delay = delay.saturating_mul(2);
709            i += 1;
710        }
711        delay
712    }
713}
714
715// ── Transport ───────────────────────────────────────────────────────────────
716
717/// Encrypt a transport data packet
718pub fn encrypt_transport(
719    session: &SessionKeys,
720    receiver_index: u32,
721    payload: &[u8],
722) -> Result<Vec<u8>, WireGuardError> {
723    let nonce = session.next_nonce()?;
724
725    // Pad payload to 16-byte boundary
726    let padded_len = (payload.len() + 15) & !15;
727    let mut padded = Vec::with_capacity(padded_len);
728    padded.extend_from_slice(payload);
729    padded.resize(padded_len, 0);
730
731    // Encrypt
732    let encrypted = aead_encrypt(&session.sending_key, nonce, &[], &padded);
733
734    // Build transport message: type(4) + receiver(4) + counter(8) + encrypted
735    let mut msg = Vec::with_capacity(TRANSPORT_HEADER_SIZE + encrypted.len());
736    msg.extend_from_slice(&[MSG_TRANSPORT_DATA, 0, 0, 0]);
737    msg.extend_from_slice(&receiver_index.to_le_bytes());
738    msg.extend_from_slice(&nonce.to_le_bytes());
739    msg.extend_from_slice(&encrypted);
740
741    Ok(msg)
742}
743
744/// Decrypt a transport data packet
745pub fn decrypt_transport(
746    session: &SessionKeys,
747    replay_window: &mut AntiReplayWindow,
748    packet: &[u8],
749) -> Result<Vec<u8>, WireGuardError> {
750    if packet.len() < TRANSPORT_HEADER_SIZE + TAG_SIZE {
751        return Err(WireGuardError::InvalidMessage);
752    }
753    if packet[0] != MSG_TRANSPORT_DATA {
754        return Err(WireGuardError::InvalidMessage);
755    }
756
757    let counter = u64::from_le_bytes([
758        packet[8], packet[9], packet[10], packet[11], packet[12], packet[13], packet[14],
759        packet[15],
760    ]);
761
762    // Anti-replay check
763    if !replay_window.check(counter) {
764        return Err(WireGuardError::ReplayDetected);
765    }
766
767    // Decrypt
768    let plaintext = aead_decrypt(
769        &session.receiving_key,
770        counter,
771        &[],
772        &packet[TRANSPORT_HEADER_SIZE..],
773    )?;
774
775    // Update replay window after successful decryption
776    replay_window.update(counter);
777
778    Ok(plaintext)
779}
780
781// ── Virtual Interface ───────────────────────────────────────────────────────
782
783/// WireGuard virtual network interface (wg0)
784pub struct WireGuardInterface {
785    /// Interface name
786    pub name: [u8; 16],
787    /// Local static key pair
788    pub static_key: X25519KeyPair,
789    /// Listening UDP port
790    pub listen_port: u16,
791    /// Tunnel IP address
792    pub tunnel_address: Option<IpAddress>,
793    /// Tunnel subnet prefix length
794    pub tunnel_prefix: u8,
795    /// Peer table: hash of public key -> peer
796    pub peers: BTreeMap<u64, WireGuardPeer>,
797    /// Interface MTU
798    pub mtu: u16,
799    /// Whether the interface is up
800    pub is_up: bool,
801    /// Next sender index to assign
802    next_sender_index: u32,
803    /// Packet counter for statistics
804    pub packets_in: u64,
805    pub packets_out: u64,
806}
807
808impl WireGuardInterface {
809    /// Create a new WireGuard interface
810    pub fn new(name: &[u8], static_key: X25519KeyPair, listen_port: u16) -> Self {
811        let mut name_buf = [0u8; 16];
812        let copy_len = core::cmp::min(name.len(), 15);
813        name_buf[..copy_len].copy_from_slice(&name[..copy_len]);
814        Self {
815            name: name_buf,
816            static_key,
817            listen_port,
818            tunnel_address: None,
819            tunnel_prefix: 24,
820            peers: BTreeMap::new(),
821            mtu: 1420, // Standard WireGuard MTU (1500 - 80)
822            is_up: false,
823            next_sender_index: 1,
824            packets_in: 0,
825            packets_out: 0,
826        }
827    }
828
829    /// Set tunnel IP address
830    pub fn set_address(&mut self, addr: IpAddress, prefix: u8) {
831        self.tunnel_address = Some(addr);
832        self.tunnel_prefix = prefix;
833    }
834
835    /// Calculate effective MTU based on outer transport.
836    ///
837    /// - IPv4 outer: outer_mtu - 20 (IP) - 8 (UDP) - 32 (WG overhead) =
838    ///   outer_mtu - 60
839    /// - IPv6 outer: outer_mtu - 40 (IP) - 8 (UDP) - 32 (WG overhead) =
840    ///   outer_mtu - 80
841    pub fn calculate_mtu(outer_mtu: u16, is_ipv6: bool) -> u16 {
842        let overhead = if is_ipv6 { 80u16 } else { 60u16 };
843        outer_mtu.saturating_sub(overhead)
844    }
845
846    /// Compute a hash key for a peer's public key
847    fn peer_key_hash(public_key: &[u8; 32]) -> u64 {
848        let hash = blake2s(public_key, 32);
849        u64::from_le_bytes([
850            hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7],
851        ])
852    }
853
854    /// Add a peer
855    pub fn add_peer(&mut self, peer: WireGuardPeer) -> Result<(), WireGuardError> {
856        let key = Self::peer_key_hash(&peer.public_key);
857        self.peers.insert(key, peer);
858        Ok(())
859    }
860
861    /// Remove a peer by public key
862    pub fn remove_peer(&mut self, public_key: &[u8; 32]) -> Result<(), WireGuardError> {
863        let key = Self::peer_key_hash(public_key);
864        self.peers
865            .remove(&key)
866            .map(|_| ())
867            .ok_or(WireGuardError::PeerNotFound)
868    }
869
870    /// Look up a peer by public key
871    pub fn get_peer(&self, public_key: &[u8; 32]) -> Option<&WireGuardPeer> {
872        let key = Self::peer_key_hash(public_key);
873        self.peers.get(&key)
874    }
875
876    /// Look up a peer mutably by public key
877    pub fn get_peer_mut(&mut self, public_key: &[u8; 32]) -> Option<&mut WireGuardPeer> {
878        let key = Self::peer_key_hash(public_key);
879        self.peers.get_mut(&key)
880    }
881
882    /// Find a peer that handles a given destination IP
883    pub fn find_peer_for_ip(&self, dst: &Ipv4Address) -> Option<&WireGuardPeer> {
884        self.peers.values().find(|peer| peer.is_allowed(dst))
885    }
886
887    /// Bring the interface up
888    pub fn up(&mut self) -> Result<(), WireGuardError> {
889        if self.tunnel_address.is_none() {
890            return Err(WireGuardError::NotConfigured);
891        }
892        self.is_up = true;
893        Ok(())
894    }
895
896    /// Bring the interface down
897    pub fn down(&mut self) {
898        self.is_up = false;
899    }
900
901    /// Allocate a new sender index
902    pub fn alloc_sender_index(&mut self) -> u32 {
903        let idx = self.next_sender_index;
904        self.next_sender_index = self.next_sender_index.wrapping_add(1);
905        if self.next_sender_index == 0 {
906            self.next_sender_index = 1;
907        }
908        idx
909    }
910
911    /// Get peer count
912    pub fn peer_count(&self) -> usize {
913        self.peers.len()
914    }
915}
916
917// ── Timer Management ────────────────────────────────────────────────────────
918
919/// Timer events for WireGuard session management
920#[derive(Debug, Clone, Copy, PartialEq, Eq)]
921pub enum TimerEvent {
922    /// Time to initiate a rekey
923    RekeyInitiate,
924    /// Handshake retry needed
925    HandshakeRetry,
926    /// Session has expired, clear keys
927    SessionExpiry,
928    /// Dead peer detected
929    DeadPeer,
930    /// Send keepalive
931    Keepalive,
932}
933
934/// Timer state for a peer
935pub struct PeerTimers {
936    /// Handshake initiated timestamp (ms)
937    pub handshake_initiated_ms: u64,
938    /// Last keepalive sent timestamp (s)
939    pub last_keepalive_sent: u64,
940    /// Whether a rekey is pending
941    pub rekey_pending: bool,
942}
943
944impl Default for PeerTimers {
945    fn default() -> Self {
946        Self::new()
947    }
948}
949
950impl PeerTimers {
951    pub fn new() -> Self {
952        Self {
953            handshake_initiated_ms: 0,
954            last_keepalive_sent: 0,
955            rekey_pending: false,
956        }
957    }
958}
959
960/// Check timer events for a peer
961pub fn check_peer_timers(
962    peer: &WireGuardPeer,
963    timers: &PeerTimers,
964    now_secs: u64,
965    now_ms: u64,
966) -> Option<TimerEvent> {
967    // Check session expiry first (highest priority)
968    if peer.is_session_expired(now_secs) {
969        return Some(TimerEvent::SessionExpiry);
970    }
971
972    // Check rekey needed
973    if let Some(ref session) = peer.session {
974        if session.needs_rekey(now_secs) && !timers.rekey_pending {
975            return Some(TimerEvent::RekeyInitiate);
976        }
977    }
978
979    // Check handshake retry
980    if peer.handshake_state == HandshakeState::InitSent {
981        if peer.handshake_retries >= MAX_HANDSHAKE_RETRIES {
982            return Some(TimerEvent::DeadPeer);
983        }
984        if now_ms >= peer.next_retry_ms && peer.next_retry_ms > 0 {
985            return Some(TimerEvent::HandshakeRetry);
986        }
987    }
988
989    // Check keepalive
990    if peer.needs_keepalive(now_secs) {
991        return Some(TimerEvent::Keepalive);
992    }
993
994    None
995}
996
997// ── Tests ───────────────────────────────────────────────────────────────────
998
999#[cfg(test)]
1000mod tests {
1001    #[allow(unused_imports)]
1002    use alloc::vec;
1003
1004    use super::*;
1005
1006    // ── BLAKE2s Tests ───────────────────────────────────────────────────
1007
1008    #[test]
1009    fn test_blake2s_empty_input() {
1010        // RFC 7693 Appendix A: BLAKE2s-256("")
1011        let hash = blake2s(b"", 32);
1012        // Known BLAKE2s-256 of empty string
1013        assert_eq!(hash[0], 0x69);
1014        assert_eq!(hash[1], 0x21);
1015        assert_eq!(hash[2], 0x7a);
1016        assert_eq!(hash[3], 0x30);
1017    }
1018
1019    #[test]
1020    fn test_blake2s_abc() {
1021        // BLAKE2s-256("abc") known test vector
1022        let hash = blake2s(b"abc", 32);
1023        assert_eq!(hash[0], 0x50);
1024        assert_eq!(hash[1], 0x8C);
1025        // Non-zero output
1026        assert!(hash.iter().any(|&b| b != 0));
1027    }
1028
1029    #[test]
1030    fn test_blake2s_deterministic() {
1031        let h1 = blake2s(b"test data", 32);
1032        let h2 = blake2s(b"test data", 32);
1033        assert_eq!(h1, h2);
1034    }
1035
1036    #[test]
1037    fn test_blake2s_different_inputs() {
1038        let h1 = blake2s(b"hello", 32);
1039        let h2 = blake2s(b"world", 32);
1040        assert_ne!(h1, h2);
1041    }
1042
1043    #[test]
1044    fn test_blake2s_keyed_mode() {
1045        let key = [0x42u8; 32];
1046        let h1 = blake2s_keyed(&key, b"data", 32);
1047        let h2 = blake2s(b"data", 32);
1048        // Keyed hash should differ from unkeyed
1049        assert_ne!(h1, h2);
1050        // Keyed hash is deterministic
1051        let h3 = blake2s_keyed(&key, b"data", 32);
1052        assert_eq!(h1, h3);
1053    }
1054
1055    #[test]
1056    fn test_blake2s_keyed_different_keys() {
1057        let key1 = [0x01u8; 32];
1058        let key2 = [0x02u8; 32];
1059        let h1 = blake2s_keyed(&key1, b"data", 32);
1060        let h2 = blake2s_keyed(&key2, b"data", 32);
1061        assert_ne!(h1, h2);
1062    }
1063
1064    #[test]
1065    fn test_hmac_blake2s() {
1066        let key = [0xABu8; 32];
1067        let mac1 = hmac_blake2s(&key, b"message");
1068        let mac2 = hmac_blake2s(&key, b"message");
1069        assert_eq!(mac1, mac2);
1070
1071        let mac3 = hmac_blake2s(&key, b"different message");
1072        assert_ne!(mac1, mac3);
1073    }
1074
1075    // ── Anti-Replay Window Tests ────────────────────────────────────────
1076
1077    #[test]
1078    fn test_replay_window_accept_new() {
1079        let mut window = AntiReplayWindow::new();
1080        assert!(window.check(0));
1081        window.update(0);
1082        assert!(window.check(1));
1083        window.update(1);
1084        assert!(window.check(2));
1085        window.update(2);
1086        assert!(window.check(100));
1087    }
1088
1089    #[test]
1090    fn test_replay_window_reject_duplicate() {
1091        let mut window = AntiReplayWindow::new();
1092        window.update(5);
1093        assert!(!window.check(5)); // Already seen
1094    }
1095
1096    #[test]
1097    fn test_replay_window_reject_old() {
1098        let mut window = AntiReplayWindow::new();
1099        window.update(3000);
1100        // Counter 0 is now outside the 2048-bit window
1101        assert!(!window.check(0));
1102    }
1103
1104    #[test]
1105    fn test_replay_window_accept_within_window() {
1106        let mut window = AntiReplayWindow::new();
1107        window.update(100);
1108        // Counter 99 is within window and not yet seen
1109        assert!(window.check(99));
1110        window.update(99);
1111        // But now it's seen
1112        assert!(!window.check(99));
1113    }
1114
1115    #[test]
1116    fn test_replay_window_large_jump() {
1117        let mut window = AntiReplayWindow::new();
1118        window.update(0);
1119        window.update(10000);
1120        // All old counters are outside the window
1121        assert!(!window.check(0));
1122        assert!(!window.check(100));
1123        // New counter is accepted
1124        assert!(window.check(10001));
1125    }
1126
1127    // ── Peer Management Tests ───────────────────────────────────────────
1128
1129    #[test]
1130    fn test_peer_add_remove() {
1131        let seed = [1u8; 32];
1132        let key = X25519KeyPair::from_seed(&seed);
1133        let mut iface = WireGuardInterface::new(b"wg0", key, DEFAULT_PORT);
1134
1135        let peer_pub = [0x42u8; 32];
1136        let peer = WireGuardPeer::new(peer_pub);
1137        assert!(iface.add_peer(peer).is_ok());
1138        assert_eq!(iface.peer_count(), 1);
1139
1140        assert!(iface.get_peer(&peer_pub).is_some());
1141        assert!(iface.remove_peer(&peer_pub).is_ok());
1142        assert_eq!(iface.peer_count(), 0);
1143        assert!(iface.get_peer(&peer_pub).is_none());
1144    }
1145
1146    #[test]
1147    fn test_peer_remove_not_found() {
1148        let seed = [1u8; 32];
1149        let key = X25519KeyPair::from_seed(&seed);
1150        let mut iface = WireGuardInterface::new(b"wg0", key, DEFAULT_PORT);
1151
1152        let fake_pub = [0xFFu8; 32];
1153        assert_eq!(
1154            iface.remove_peer(&fake_pub),
1155            Err(WireGuardError::PeerNotFound)
1156        );
1157    }
1158
1159    #[test]
1160    fn test_peer_lookup() {
1161        let seed = [1u8; 32];
1162        let key = X25519KeyPair::from_seed(&seed);
1163        let mut iface = WireGuardInterface::new(b"wg0", key, DEFAULT_PORT);
1164
1165        let pub1 = [0x01u8; 32];
1166        let pub2 = [0x02u8; 32];
1167        iface.add_peer(WireGuardPeer::new(pub1)).unwrap();
1168        iface.add_peer(WireGuardPeer::new(pub2)).unwrap();
1169
1170        assert!(iface.get_peer(&pub1).is_some());
1171        assert!(iface.get_peer(&pub2).is_some());
1172        assert_eq!(iface.peer_count(), 2);
1173    }
1174
1175    // ── Key Rotation and Timer Tests ────────────────────────────────────
1176
1177    #[test]
1178    fn test_key_rotation_by_time() {
1179        let keys = SessionKeys::new([1u8; 32], [2u8; 32], 0);
1180        assert!(!keys.needs_rekey(0));
1181        assert!(!keys.needs_rekey(119));
1182        assert!(keys.needs_rekey(120)); // REKEY_AFTER_SECONDS
1183        assert!(keys.needs_rekey(200));
1184    }
1185
1186    #[test]
1187    fn test_key_rotation_by_messages() {
1188        let keys = SessionKeys::new([1u8; 32], [2u8; 32], 0);
1189        // Simulate many messages by setting counter directly
1190        keys.messages_sent
1191            .store(REKEY_AFTER_MESSAGES, Ordering::Relaxed);
1192        assert!(keys.needs_rekey(0));
1193    }
1194
1195    // ── MTU Calculation Tests ───────────────────────────────────────────
1196
1197    #[test]
1198    fn test_mtu_calculation_ipv4() {
1199        // Standard 1500 MTU - 60 (IPv4 overhead) = 1440
1200        assert_eq!(WireGuardInterface::calculate_mtu(1500, false), 1440);
1201    }
1202
1203    #[test]
1204    fn test_mtu_calculation_ipv6() {
1205        // Standard 1500 MTU - 80 (IPv6 overhead) = 1420
1206        assert_eq!(WireGuardInterface::calculate_mtu(1500, true), 1420);
1207    }
1208
1209    #[test]
1210    fn test_mtu_calculation_small() {
1211        // Saturating subtraction prevents underflow
1212        assert_eq!(WireGuardInterface::calculate_mtu(50, false), 0);
1213        assert_eq!(WireGuardInterface::calculate_mtu(60, false), 0);
1214        assert_eq!(WireGuardInterface::calculate_mtu(61, false), 1);
1215    }
1216
1217    // ── Nonce Counter Tests ─────────────────────────────────────────────
1218
1219    #[test]
1220    fn test_nonce_counter_increment() {
1221        let keys = SessionKeys::new([1u8; 32], [2u8; 32], 0);
1222        assert_eq!(keys.next_nonce().unwrap(), 0);
1223        assert_eq!(keys.next_nonce().unwrap(), 1);
1224        assert_eq!(keys.next_nonce().unwrap(), 2);
1225    }
1226
1227    #[test]
1228    fn test_nonce_counter_overflow() {
1229        let keys = SessionKeys::new([1u8; 32], [2u8; 32], 0);
1230        keys.sending_nonce
1231            .store(REKEY_AFTER_MESSAGES, Ordering::Relaxed);
1232        assert_eq!(keys.next_nonce(), Err(WireGuardError::NonceOverflow));
1233    }
1234
1235    // ── Allowed IP Matching Tests ───────────────────────────────────────
1236
1237    #[test]
1238    fn test_allowed_ip_exact_match() {
1239        let aip = AllowedIp::new(Ipv4Address::new(10, 0, 0, 1), 32);
1240        assert!(aip.matches(&Ipv4Address::new(10, 0, 0, 1)));
1241        assert!(!aip.matches(&Ipv4Address::new(10, 0, 0, 2)));
1242    }
1243
1244    #[test]
1245    fn test_allowed_ip_subnet_match() {
1246        let aip = AllowedIp::new(Ipv4Address::new(10, 0, 0, 0), 24);
1247        assert!(aip.matches(&Ipv4Address::new(10, 0, 0, 1)));
1248        assert!(aip.matches(&Ipv4Address::new(10, 0, 0, 254)));
1249        assert!(!aip.matches(&Ipv4Address::new(10, 0, 1, 1)));
1250    }
1251
1252    #[test]
1253    fn test_allowed_ip_wildcard() {
1254        let aip = AllowedIp::new(Ipv4Address::new(0, 0, 0, 0), 0);
1255        assert!(aip.matches(&Ipv4Address::new(192, 168, 1, 1)));
1256        assert!(aip.matches(&Ipv4Address::new(10, 0, 0, 1)));
1257    }
1258
1259    // ── Session State Transitions ───────────────────────────────────────
1260
1261    #[test]
1262    fn test_session_state_transitions() {
1263        let mut peer = WireGuardPeer::new([0x01u8; 32]);
1264        assert_eq!(peer.handshake_state, HandshakeState::None);
1265
1266        peer.handshake_state = HandshakeState::InitSent;
1267        assert_eq!(peer.handshake_state, HandshakeState::InitSent);
1268
1269        peer.handshake_state = HandshakeState::Established;
1270        assert_eq!(peer.handshake_state, HandshakeState::Established);
1271    }
1272
1273    #[test]
1274    fn test_session_expiry() {
1275        let mut peer = WireGuardPeer::new([0x01u8; 32]);
1276        peer.last_received = 100;
1277        peer.last_sent = 100;
1278
1279        assert!(!peer.is_session_expired(200)); // 100s elapsed < 180s threshold
1280        assert!(peer.is_session_expired(281)); // 181s elapsed >= 180s threshold
1281    }
1282
1283    #[test]
1284    fn test_handshake_retry_backoff() {
1285        let mut peer = WireGuardPeer::new([0x01u8; 32]);
1286        peer.handshake_retries = 0;
1287        assert_eq!(peer.retry_delay_ms(), 1000);
1288
1289        peer.handshake_retries = 1;
1290        assert_eq!(peer.retry_delay_ms(), 2000);
1291
1292        peer.handshake_retries = 2;
1293        assert_eq!(peer.retry_delay_ms(), 4000);
1294
1295        peer.handshake_retries = 3;
1296        assert_eq!(peer.retry_delay_ms(), 8000);
1297
1298        peer.handshake_retries = 4;
1299        assert_eq!(peer.retry_delay_ms(), 16000);
1300
1301        // Max retries exceeded -> give up
1302        peer.handshake_retries = MAX_HANDSHAKE_RETRIES;
1303        assert_eq!(peer.retry_delay_ms(), 0);
1304    }
1305
1306    // ── Handshake Message Construction ──────────────────────────────────
1307
1308    #[test]
1309    fn test_handshake_initiation_size() {
1310        let static_key = X25519KeyPair::from_seed(&[0x11u8; 32]);
1311        let remote_pub = [0x22u8; 32];
1312        let psk = [0x33u8; 32];
1313        let timestamp = [0u8; 12];
1314
1315        let mut ctx = HandshakeContext::new();
1316        let msg = ctx.create_initiation(&static_key, &remote_pub, &psk, &timestamp, 1);
1317
1318        assert_eq!(msg.len(), HANDSHAKE_INIT_SIZE);
1319        assert_eq!(msg[0], MSG_HANDSHAKE_INIT);
1320        // Sender index at offset 4
1321        assert_eq!(u32::from_le_bytes([msg[4], msg[5], msg[6], msg[7]]), 1);
1322    }
1323
1324    #[test]
1325    fn test_handshake_response_size() {
1326        let static_key = X25519KeyPair::from_seed(&[0x44u8; 32]);
1327        let remote_pub = [0x55u8; 32];
1328        let psk = [0x66u8; 32];
1329
1330        let mut ctx = HandshakeContext::new();
1331        let msg = ctx.create_response(&static_key, &remote_pub, &psk, 2, 1);
1332
1333        assert_eq!(msg.len(), HANDSHAKE_RESP_SIZE);
1334        assert_eq!(msg[0], MSG_HANDSHAKE_RESP);
1335        // Sender index at offset 4
1336        assert_eq!(u32::from_le_bytes([msg[4], msg[5], msg[6], msg[7]]), 2);
1337        // Receiver index at offset 8
1338        assert_eq!(u32::from_le_bytes([msg[8], msg[9], msg[10], msg[11]]), 1);
1339    }
1340
1341    // ── Transport Encrypt/Decrypt ───────────────────────────────────────
1342
1343    #[test]
1344    fn test_transport_encrypt_decrypt() {
1345        let send_key = [0xAAu8; 32];
1346        let recv_key = [0xAAu8; 32]; // Same for round-trip test
1347        let send_session = SessionKeys::new(send_key, [0u8; 32], 0);
1348        let recv_session = SessionKeys::new([0u8; 32], recv_key, 0);
1349
1350        let payload = b"hello wireguard";
1351        let encrypted = encrypt_transport(&send_session, 42, payload).unwrap();
1352
1353        // Header: type(4) + receiver(4) + counter(8)
1354        assert_eq!(encrypted[0], MSG_TRANSPORT_DATA);
1355        assert_eq!(
1356            u32::from_le_bytes([encrypted[4], encrypted[5], encrypted[6], encrypted[7]]),
1357            42
1358        );
1359
1360        let mut window = AntiReplayWindow::new();
1361        let decrypted = decrypt_transport(&recv_session, &mut window, &encrypted).unwrap();
1362        // Decrypted payload is padded to 16-byte boundary
1363        assert!(decrypted.len() >= payload.len());
1364        assert_eq!(&decrypted[..payload.len()], payload);
1365    }
1366
1367    // ── Timer Event Tests ───────────────────────────────────────────────
1368
1369    #[test]
1370    fn test_keepalive_timing() {
1371        let mut peer = WireGuardPeer::new([0x01u8; 32]);
1372        peer.handshake_state = HandshakeState::Established;
1373        peer.keepalive_interval = 25;
1374        peer.last_sent = 100;
1375
1376        assert!(!peer.needs_keepalive(120)); // 20s < 25s
1377        assert!(peer.needs_keepalive(125)); // 25s >= 25s
1378        assert!(peer.needs_keepalive(200)); // 100s >= 25s
1379    }
1380
1381    #[test]
1382    fn test_keepalive_disabled() {
1383        let mut peer = WireGuardPeer::new([0x01u8; 32]);
1384        peer.handshake_state = HandshakeState::Established;
1385        peer.keepalive_interval = 0;
1386        peer.last_sent = 0;
1387
1388        assert!(!peer.needs_keepalive(1000));
1389    }
1390
1391    #[test]
1392    fn test_timer_event_session_expiry() {
1393        let mut peer = WireGuardPeer::new([0x01u8; 32]);
1394        peer.last_received = 100;
1395        peer.last_sent = 100;
1396        let timers = PeerTimers::new();
1397
1398        let event = check_peer_timers(&peer, &timers, 300, 300_000);
1399        assert_eq!(event, Some(TimerEvent::SessionExpiry));
1400    }
1401
1402    #[test]
1403    fn test_timer_event_dead_peer() {
1404        let mut peer = WireGuardPeer::new([0x01u8; 32]);
1405        peer.handshake_state = HandshakeState::InitSent;
1406        peer.handshake_retries = MAX_HANDSHAKE_RETRIES;
1407        let timers = PeerTimers::new();
1408
1409        let event = check_peer_timers(&peer, &timers, 0, 0);
1410        assert_eq!(event, Some(TimerEvent::DeadPeer));
1411    }
1412
1413    #[test]
1414    fn test_interface_up_down() {
1415        let seed = [1u8; 32];
1416        let key = X25519KeyPair::from_seed(&seed);
1417        let mut iface = WireGuardInterface::new(b"wg0", key, DEFAULT_PORT);
1418
1419        // Cannot bring up without address
1420        assert_eq!(iface.up(), Err(WireGuardError::NotConfigured));
1421
1422        iface.set_address(IpAddress::V4(Ipv4Address::new(10, 0, 0, 1)), 24);
1423        assert!(iface.up().is_ok());
1424        assert!(iface.is_up);
1425
1426        iface.down();
1427        assert!(!iface.is_up);
1428    }
1429}