1#![allow(dead_code)]
14
15use alloc::{collections::BTreeMap, vec::Vec};
16use core::sync::atomic::{AtomicU64, Ordering};
17
18use super::{IpAddress, Ipv4Address};
19use crate::crypto::{
20 cipher_suite::{CipherSuite, HmacAlgorithm, KdfAlgorithm},
21 hash::{blake2s_hash, blake2s_keyed_hash},
22};
23
/// Default listen port; the tests in this file pass it to `WireGuardInterface::new`.
pub const DEFAULT_PORT: u16 = 51820;

/// Message-type byte stored at offset 0 of a handshake initiation.
const MSG_HANDSHAKE_INIT: u8 = 1;
/// Message-type byte stored at offset 0 of a handshake response.
const MSG_HANDSHAKE_RESP: u8 = 2;
/// Message-type value for a cookie reply (not referenced by the visible code).
const MSG_COOKIE_REPLY: u8 = 3;
/// Message-type byte stored at offset 0 of a transport data packet.
const MSG_TRANSPORT_DATA: u8 = 4;

/// Length in bytes of the array returned by `HandshakeContext::create_initiation`.
const HANDSHAKE_INIT_SIZE: usize = 148;

/// Length in bytes of the array returned by `HandshakeContext::create_response`.
const HANDSHAKE_RESP_SIZE: usize = 92;

/// Key length in bytes (not referenced by the visible code, which spells out `32`).
const KEY_SIZE: usize = 32;

/// Length of the nonce buffer built by `aead_encrypt` / `aead_decrypt`.
const CHACHA_NONCE_SIZE: usize = 12;

/// Minimum length `aead_decrypt` accepts for its `ciphertext_and_tag` argument.
const TAG_SIZE: usize = 16;

/// Number of counters below the newest one that `AntiReplayWindow` tracks.
const REPLAY_WINDOW_BITS: usize = 2048;

/// Number of `u64` words backing the replay bitmap.
const REPLAY_WINDOW_WORDS: usize = REPLAY_WINDOW_BITS / 64;

/// Message count at which `SessionKeys::needs_rekey` becomes true and
/// `SessionKeys::next_nonce` starts refusing to hand out nonces.
const REKEY_AFTER_MESSAGES: u64 = 1u64 << 60;

/// Session age (same unit as `SessionKeys::created_at`) at which a rekey is due.
const REKEY_AFTER_SECONDS: u64 = 120;

/// Idle time after which `WireGuardPeer::is_session_expired` reports expiry.
const SESSION_EXPIRY_SECONDS: u64 = 180;

/// Keepalive interval assigned to every peer created by `WireGuardPeer::new`.
const DEFAULT_KEEPALIVE_INTERVAL: u64 = 25;

/// Retry count at which a pending handshake is reported as `TimerEvent::DeadPeer`.
const MAX_HANDSHAKE_RETRIES: u32 = 5;

/// Retry delay in milliseconds before any retry has happened; doubled per retry.
const INITIAL_RETRY_DELAY_MS: u64 = 1000;

/// Construction string hashed to produce the initial chaining key.
const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";

/// Identifier string hashed together with the chaining key into the initial hash.
const IDENTIFIER: &[u8] = b"WireGuard v1 zx2c4 Jason@zx2c4.com";

/// Transport header length: 4 type/reserved bytes, 4 receiver-index bytes and
/// 8 counter bytes, as laid out by `encrypt_transport`.
const TRANSPORT_HEADER_SIZE: usize = 16;
82
/// Hashes `data` with the crate's `blake2s_hash`, forwarding `outlen` unchanged.
///
/// The result is always a 32-byte array; presumably `outlen` selects how many
/// of those bytes are meaningful — TODO confirm against `blake2s_hash`.
pub fn blake2s(data: &[u8], outlen: usize) -> [u8; 32] {
    blake2s_hash(data, outlen)
}
91
/// Keyed variant of [`blake2s`]: forwards `key`, `data` and `outlen` to the
/// crate's `blake2s_keyed_hash` and returns its 32-byte result.
///
/// Callers in this file request `outlen = 16` and then use only the first
/// 16 bytes of the returned array.
pub fn blake2s_keyed(key: &[u8], data: &[u8], outlen: usize) -> [u8; 32] {
    blake2s_keyed_hash(key, data, outlen)
}
98
/// Computes `HmacAlgorithm::HmacBlake2s` over `data` under `key` and returns
/// the 32-byte result.
pub fn hmac_blake2s(key: &[u8], data: &[u8]) -> [u8; 32] {
    HmacAlgorithm::HmacBlake2s.compute(key, data)
}
103
/// Two-output key derivation: runs `KdfAlgorithm::HkdfBlake2s.extract_expand2`
/// on `chaining_key` and `input`.
///
/// Callers in this file treat the first output as the next chaining key and
/// the second as a derived key.
fn hkdf(chaining_key: &[u8; 32], input: &[u8]) -> ([u8; 32], [u8; 32]) {
    KdfAlgorithm::HkdfBlake2s.extract_expand2(chaining_key, input)
}
110
/// Three-output key derivation: runs `KdfAlgorithm::HkdfBlake2s.extract_expand3`
/// on `chaining_key` and `input`.
///
/// The only caller in this file uses the outputs as (next chaining key,
/// value mixed into the hash, AEAD key), in that order.
fn hkdf3(chaining_key: &[u8; 32], input: &[u8]) -> ([u8; 32], [u8; 32], [u8; 32]) {
    KdfAlgorithm::HkdfBlake2s.extract_expand3(chaining_key, input)
}
117
118#[derive(Clone)]
125pub struct X25519KeyPair {
126 pub private_key: [u8; 32],
127 pub public_key: [u8; 32],
128}
129
130impl X25519KeyPair {
131 pub fn from_seed(seed: &[u8; 32]) -> Self {
133 let mut private_key = *seed;
134 private_key[0] &= 248;
136 private_key[31] &= 127;
137 private_key[31] |= 64;
138
139 let public_key = crate::crypto::asymmetric::x25519_scalar_mult(
141 &private_key,
142 &crate::crypto::asymmetric::X25519_BASEPOINT,
143 );
144 Self {
145 private_key,
146 public_key,
147 }
148 }
149
150 pub fn dh(&self, their_public: &[u8; 32]) -> [u8; 32] {
152 crate::crypto::asymmetric::x25519_scalar_mult(&self.private_key, their_public)
153 }
154}
155
156fn aead_encrypt(key: &[u8; 32], nonce: u64, aad: &[u8], plaintext: &[u8]) -> Vec<u8> {
163 let mut nonce_bytes = [0u8; CHACHA_NONCE_SIZE];
164 nonce_bytes[4..12].copy_from_slice(&nonce.to_le_bytes());
165
166 CipherSuite::ChaCha20Poly1305
167 .encrypt_aead(key, &nonce_bytes, aad, plaintext)
168 .unwrap_or_default()
169}
170
171fn aead_decrypt(
176 key: &[u8; 32],
177 nonce: u64,
178 aad: &[u8],
179 ciphertext_and_tag: &[u8],
180) -> Result<Vec<u8>, WireGuardError> {
181 if ciphertext_and_tag.len() < TAG_SIZE {
182 return Err(WireGuardError::DecryptionFailed);
183 }
184
185 let mut nonce_bytes = [0u8; CHACHA_NONCE_SIZE];
186 nonce_bytes[4..12].copy_from_slice(&nonce.to_le_bytes());
187
188 CipherSuite::ChaCha20Poly1305
189 .decrypt_aead(key, &nonce_bytes, aad, ciphertext_and_tag)
190 .map_err(|_| WireGuardError::DecryptionFailed)
191}
192
/// Errors reported by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireGuardError {
    /// A handshake failed (not produced by the code visible in this file).
    HandshakeFailed,
    /// AEAD input was too short or was rejected by the cipher suite.
    DecryptionFailed,
    /// The packet counter was rejected by the anti-replay window.
    ReplayDetected,
    /// The session expired (not produced by the code visible in this file).
    SessionExpired,
    /// No peer is stored under the given public key.
    PeerNotFound,
    /// A packet was too short or carried the wrong message type.
    InvalidMessage,
    /// The sending nonce reached `REKEY_AFTER_MESSAGES`.
    NonceOverflow,
    /// A rekey is required (not produced by the code visible in this file).
    RekeyRequired,
    /// The interface has no tunnel address and cannot be brought up.
    NotConfigured,
    /// The peer limit was reached (not produced by the code visible in this file).
    MaxPeersReached,
}
219
220#[derive(Clone)]
224pub struct AntiReplayWindow {
225 last_counter: u64,
227 bitmap: [u64; REPLAY_WINDOW_WORDS],
229}
230
231impl Default for AntiReplayWindow {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
237impl AntiReplayWindow {
238 pub fn new() -> Self {
240 Self {
241 last_counter: 0,
242 bitmap: [0u64; REPLAY_WINDOW_WORDS],
243 }
244 }
245
246 pub fn check(&self, counter: u64) -> bool {
248 if counter == 0 && self.last_counter == 0 && self.bitmap[0] == 0 {
249 return true;
251 }
252 if counter > self.last_counter {
253 return true;
254 }
255 let diff = self.last_counter - counter;
256 if diff >= REPLAY_WINDOW_BITS as u64 {
257 return false; }
259 let word_idx = (diff / 64) as usize;
260 let bit_idx = (diff % 64) as u32;
261 if word_idx >= REPLAY_WINDOW_WORDS {
262 return false;
263 }
264 (self.bitmap[word_idx] & (1u64 << bit_idx)) == 0
265 }
266
267 pub fn update(&mut self, counter: u64) {
269 if counter > self.last_counter {
270 let shift = counter - self.last_counter;
271 if shift >= REPLAY_WINDOW_BITS as u64 {
272 self.bitmap = [0u64; REPLAY_WINDOW_WORDS];
274 } else {
275 self.shift_window(shift as usize);
276 }
277 self.last_counter = counter;
278 self.bitmap[0] |= 1;
280 } else {
281 let diff = self.last_counter - counter;
282 let word_idx = (diff / 64) as usize;
283 let bit_idx = (diff % 64) as u32;
284 if word_idx < REPLAY_WINDOW_WORDS {
285 self.bitmap[word_idx] |= 1u64 << bit_idx;
286 }
287 }
288 }
289
290 fn shift_window(&mut self, shift: usize) {
292 let word_shift = shift / 64;
293 let bit_shift = (shift % 64) as u32;
294
295 if word_shift >= REPLAY_WINDOW_WORDS {
296 self.bitmap = [0u64; REPLAY_WINDOW_WORDS];
297 return;
298 }
299
300 if word_shift > 0 {
302 let mut i = REPLAY_WINDOW_WORDS;
303 while i > word_shift {
304 i -= 1;
305 self.bitmap[i] = self.bitmap[i - word_shift];
306 }
307 let mut j = 0;
308 while j < word_shift {
309 self.bitmap[j] = 0;
310 j += 1;
311 }
312 }
313
314 if bit_shift > 0 {
316 let mut i = REPLAY_WINDOW_WORDS;
317 while i > 1 {
318 i -= 1;
319 self.bitmap[i] =
320 (self.bitmap[i] << bit_shift) | (self.bitmap[i - 1] >> (64 - bit_shift));
321 }
322 self.bitmap[0] <<= bit_shift;
323 }
324 }
325}
326
/// Handshake progress recorded on a peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeState {
    /// Initial state of a freshly created peer.
    None,
    /// Handshake retry and dead-peer timers are only evaluated in this state.
    InitSent,
    /// Not assigned or inspected by the code visible in this file.
    InitReceived,
    /// Keepalives are only scheduled in this state.
    Established,
}
341
342pub struct SessionKeys {
344 pub sending_key: [u8; 32],
346 pub receiving_key: [u8; 32],
348 pub sending_nonce: AtomicU64,
350 pub created_at: u64,
352 pub messages_sent: AtomicU64,
354}
355
356impl SessionKeys {
357 pub fn new(sending_key: [u8; 32], receiving_key: [u8; 32], now: u64) -> Self {
359 Self {
360 sending_key,
361 receiving_key,
362 sending_nonce: AtomicU64::new(0),
363 created_at: now,
364 messages_sent: AtomicU64::new(0),
365 }
366 }
367
368 pub fn needs_rekey(&self, now: u64) -> bool {
370 let messages = self.messages_sent.load(Ordering::Relaxed);
371 let age = now.saturating_sub(self.created_at);
372 messages >= REKEY_AFTER_MESSAGES || age >= REKEY_AFTER_SECONDS
373 }
374
375 pub fn next_nonce(&self) -> Result<u64, WireGuardError> {
377 let nonce = self.sending_nonce.fetch_add(1, Ordering::Relaxed);
378 if nonce >= REKEY_AFTER_MESSAGES {
379 return Err(WireGuardError::NonceOverflow);
380 }
381 self.messages_sent.fetch_add(1, Ordering::Relaxed);
382 Ok(nonce)
383 }
384}
385
386#[derive(Clone)]
388pub struct HandshakeContext {
389 pub chaining_key: [u8; 32],
391 pub hash: [u8; 32],
393 pub ephemeral: Option<X25519KeyPair>,
395 pub remote_ephemeral: Option<[u8; 32]>,
397 pub sender_index: u32,
399 pub receiver_index: u32,
401}
402
403impl Default for HandshakeContext {
404 fn default() -> Self {
405 Self::new()
406 }
407}
408
409impl HandshakeContext {
410 pub fn new() -> Self {
412 let chaining_key = blake2s(CONSTRUCTION, 32);
414 let mut hash_input = Vec::with_capacity(32 + IDENTIFIER.len());
416 hash_input.extend_from_slice(&chaining_key);
417 hash_input.extend_from_slice(IDENTIFIER);
418 let hash = blake2s(&hash_input, 32);
419 Self {
420 chaining_key,
421 hash,
422 ephemeral: None,
423 remote_ephemeral: None,
424 sender_index: 0,
425 receiver_index: 0,
426 }
427 }
428
429 pub fn mix_hash(&mut self, data: &[u8]) {
431 let mut input = Vec::with_capacity(32 + data.len());
432 input.extend_from_slice(&self.hash);
433 input.extend_from_slice(data);
434 self.hash = blake2s(&input, 32);
435 }
436
437 pub fn create_initiation(
439 &mut self,
440 static_key: &X25519KeyPair,
441 remote_static_pub: &[u8; 32],
442 preshared_key: &[u8; 32],
443 timestamp: &[u8; 12],
444 sender_index: u32,
445 ) -> [u8; HANDSHAKE_INIT_SIZE] {
446 let mut msg = [0u8; HANDSHAKE_INIT_SIZE];
447 self.sender_index = sender_index;
448
449 self.mix_hash(remote_static_pub);
451
452 let seed = hmac_blake2s(&static_key.private_key, timestamp);
454 let ephemeral = X25519KeyPair::from_seed(&seed);
455
456 msg[0] = MSG_HANDSHAKE_INIT;
458 msg[4..8].copy_from_slice(&sender_index.to_le_bytes());
460 msg[8..40].copy_from_slice(&ephemeral.public_key);
462 self.mix_hash(&ephemeral.public_key);
463
464 let dh_result = ephemeral.dh(remote_static_pub);
466 let (ck, key) = hkdf(&self.chaining_key, &dh_result);
467 self.chaining_key = ck;
468
469 let encrypted_static = aead_encrypt(&key, 0, &self.hash, &static_key.public_key);
471 let copy_len = core::cmp::min(encrypted_static.len(), 48);
472 msg[40..40 + copy_len].copy_from_slice(&encrypted_static[..copy_len]);
473 self.mix_hash(&msg[40..40 + copy_len]);
474
475 let dh_result2 = static_key.dh(remote_static_pub);
477 let (ck2, key2) = hkdf(&self.chaining_key, &dh_result2);
478 self.chaining_key = ck2;
479
480 let encrypted_ts = aead_encrypt(&key2, 0, &self.hash, timestamp);
482 let ts_len = core::cmp::min(encrypted_ts.len(), 28);
483 msg[88..88 + ts_len].copy_from_slice(&encrypted_ts[..ts_len]);
484 self.mix_hash(&msg[88..88 + ts_len]);
485
486 let (ck3, psk_key) = hkdf(&self.chaining_key, preshared_key);
488 self.chaining_key = ck3;
489 let _ = psk_key; let mac_key = blake2s(remote_static_pub, 32);
493 let mac1 = blake2s_keyed(&mac_key, &msg[..116], 16);
494 msg[116..132].copy_from_slice(&mac1[..16]);
495
496 self.ephemeral = Some(ephemeral);
500 msg
501 }
502
503 pub fn create_response(
505 &mut self,
506 static_key: &X25519KeyPair,
507 remote_static_pub: &[u8; 32],
508 preshared_key: &[u8; 32],
509 sender_index: u32,
510 receiver_index: u32,
511 ) -> [u8; HANDSHAKE_RESP_SIZE] {
512 let mut msg = [0u8; HANDSHAKE_RESP_SIZE];
513
514 let seed = hmac_blake2s(&static_key.private_key, &sender_index.to_le_bytes());
516 let ephemeral = X25519KeyPair::from_seed(&seed);
517
518 msg[0] = MSG_HANDSHAKE_RESP;
520 msg[4..8].copy_from_slice(&sender_index.to_le_bytes());
522 msg[8..12].copy_from_slice(&receiver_index.to_le_bytes());
524 msg[12..44].copy_from_slice(&ephemeral.public_key);
526 self.mix_hash(&ephemeral.public_key);
527
528 if let Some(ref remote_eph) = self.remote_ephemeral {
530 let dh1 = ephemeral.dh(remote_eph);
531 let (ck, _) = hkdf(&self.chaining_key, &dh1);
532 self.chaining_key = ck;
533 }
534
535 let dh2 = ephemeral.dh(remote_static_pub);
537 let (ck2, _) = hkdf(&self.chaining_key, &dh2);
538 self.chaining_key = ck2;
539
540 let (ck3, tau, key) = hkdf3(&self.chaining_key, preshared_key);
542 self.chaining_key = ck3;
543 self.mix_hash(&tau);
544
545 let encrypted_empty = aead_encrypt(&key, 0, &self.hash, &[]);
547 let empty_len = core::cmp::min(encrypted_empty.len(), 16);
548 msg[44..44 + empty_len].copy_from_slice(&encrypted_empty[..empty_len]);
549 self.mix_hash(&msg[44..44 + empty_len]);
550
551 let mac_key = blake2s(remote_static_pub, 32);
553 let mac1 = blake2s_keyed(&mac_key, &msg[..60], 16);
554 msg[60..76].copy_from_slice(&mac1[..16]);
555
556 self.sender_index = sender_index;
559 self.receiver_index = receiver_index;
560 self.ephemeral = Some(ephemeral);
561 msg
562 }
563
564 pub fn derive_transport_keys(&self) -> (SessionKeys, SessionKeys) {
566 let (t1, t2) = hkdf(&self.chaining_key, &[]);
567 let now = 0u64; (SessionKeys::new(t1, t2, now), SessionKeys::new(t2, t1, now))
569 }
570}
571
572#[derive(Debug, Clone, Copy, PartialEq, Eq)]
576pub struct AllowedIp {
577 pub address: Ipv4Address,
579 pub prefix_len: u8,
581}
582
583impl AllowedIp {
584 pub fn new(address: Ipv4Address, prefix_len: u8) -> Self {
586 Self {
587 address,
588 prefix_len,
589 }
590 }
591
592 pub fn matches(&self, ip: &Ipv4Address) -> bool {
594 if self.prefix_len == 0 {
595 return true; }
597 if self.prefix_len >= 32 {
598 return self.address == *ip;
599 }
600 let mask = u32::MAX << (32 - self.prefix_len);
601 (self.address.to_u32() & mask) == (ip.to_u32() & mask)
602 }
603}
604
605pub struct WireGuardPeer {
607 pub public_key: [u8; 32],
609 pub preshared_key: [u8; 32],
611 pub endpoint: Option<super::SocketAddr>,
613 pub allowed_ips: Vec<AllowedIp>,
615 pub handshake_state: HandshakeState,
617 pub handshake: HandshakeContext,
619 pub session: Option<SessionKeys>,
621 pub replay_window: AntiReplayWindow,
623 pub last_handshake: u64,
625 pub last_received: u64,
627 pub last_sent: u64,
629 pub keepalive_interval: u64,
631 pub handshake_retries: u32,
633 pub next_retry_ms: u64,
635 pub tx_bytes: u64,
637 pub rx_bytes: u64,
639}
640
641impl WireGuardPeer {
642 pub fn new(public_key: [u8; 32]) -> Self {
644 Self {
645 public_key,
646 preshared_key: [0u8; 32],
647 endpoint: None,
648 allowed_ips: Vec::new(),
649 handshake_state: HandshakeState::None,
650 handshake: HandshakeContext::new(),
651 session: None,
652 replay_window: AntiReplayWindow::new(),
653 last_handshake: 0,
654 last_received: 0,
655 last_sent: 0,
656 keepalive_interval: DEFAULT_KEEPALIVE_INTERVAL,
657 handshake_retries: 0,
658 next_retry_ms: 0,
659 tx_bytes: 0,
660 rx_bytes: 0,
661 }
662 }
663
664 pub fn set_preshared_key(&mut self, psk: [u8; 32]) {
666 self.preshared_key = psk;
667 }
668
669 pub fn add_allowed_ip(&mut self, ip: AllowedIp) {
671 self.allowed_ips.push(ip);
672 }
673
674 pub fn is_allowed(&self, ip: &Ipv4Address) -> bool {
676 self.allowed_ips.iter().any(|aip| aip.matches(ip))
677 }
678
679 pub fn is_session_expired(&self, now: u64) -> bool {
681 if self.last_received == 0 && self.last_sent == 0 {
682 return false; }
684 let last_activity = core::cmp::max(self.last_received, self.last_sent);
685 now.saturating_sub(last_activity) >= SESSION_EXPIRY_SECONDS
686 }
687
688 pub fn needs_keepalive(&self, now: u64) -> bool {
690 if self.keepalive_interval == 0 {
691 return false;
692 }
693 if self.handshake_state != HandshakeState::Established {
694 return false;
695 }
696 now.saturating_sub(self.last_sent) >= self.keepalive_interval
697 }
698
699 pub fn retry_delay_ms(&self) -> u64 {
701 if self.handshake_retries >= MAX_HANDSHAKE_RETRIES {
702 return 0; }
704 let mut delay = INITIAL_RETRY_DELAY_MS;
706 let mut i = 0u32;
707 while i < self.handshake_retries {
708 delay = delay.saturating_mul(2);
709 i += 1;
710 }
711 delay
712 }
713}
714
715pub fn encrypt_transport(
719 session: &SessionKeys,
720 receiver_index: u32,
721 payload: &[u8],
722) -> Result<Vec<u8>, WireGuardError> {
723 let nonce = session.next_nonce()?;
724
725 let padded_len = (payload.len() + 15) & !15;
727 let mut padded = Vec::with_capacity(padded_len);
728 padded.extend_from_slice(payload);
729 padded.resize(padded_len, 0);
730
731 let encrypted = aead_encrypt(&session.sending_key, nonce, &[], &padded);
733
734 let mut msg = Vec::with_capacity(TRANSPORT_HEADER_SIZE + encrypted.len());
736 msg.extend_from_slice(&[MSG_TRANSPORT_DATA, 0, 0, 0]);
737 msg.extend_from_slice(&receiver_index.to_le_bytes());
738 msg.extend_from_slice(&nonce.to_le_bytes());
739 msg.extend_from_slice(&encrypted);
740
741 Ok(msg)
742}
743
744pub fn decrypt_transport(
746 session: &SessionKeys,
747 replay_window: &mut AntiReplayWindow,
748 packet: &[u8],
749) -> Result<Vec<u8>, WireGuardError> {
750 if packet.len() < TRANSPORT_HEADER_SIZE + TAG_SIZE {
751 return Err(WireGuardError::InvalidMessage);
752 }
753 if packet[0] != MSG_TRANSPORT_DATA {
754 return Err(WireGuardError::InvalidMessage);
755 }
756
757 let counter = u64::from_le_bytes([
758 packet[8], packet[9], packet[10], packet[11], packet[12], packet[13], packet[14],
759 packet[15],
760 ]);
761
762 if !replay_window.check(counter) {
764 return Err(WireGuardError::ReplayDetected);
765 }
766
767 let plaintext = aead_decrypt(
769 &session.receiving_key,
770 counter,
771 &[],
772 &packet[TRANSPORT_HEADER_SIZE..],
773 )?;
774
775 replay_window.update(counter);
777
778 Ok(plaintext)
779}
780
781pub struct WireGuardInterface {
785 pub name: [u8; 16],
787 pub static_key: X25519KeyPair,
789 pub listen_port: u16,
791 pub tunnel_address: Option<IpAddress>,
793 pub tunnel_prefix: u8,
795 pub peers: BTreeMap<u64, WireGuardPeer>,
797 pub mtu: u16,
799 pub is_up: bool,
801 next_sender_index: u32,
803 pub packets_in: u64,
805 pub packets_out: u64,
806}
807
808impl WireGuardInterface {
809 pub fn new(name: &[u8], static_key: X25519KeyPair, listen_port: u16) -> Self {
811 let mut name_buf = [0u8; 16];
812 let copy_len = core::cmp::min(name.len(), 15);
813 name_buf[..copy_len].copy_from_slice(&name[..copy_len]);
814 Self {
815 name: name_buf,
816 static_key,
817 listen_port,
818 tunnel_address: None,
819 tunnel_prefix: 24,
820 peers: BTreeMap::new(),
821 mtu: 1420, is_up: false,
823 next_sender_index: 1,
824 packets_in: 0,
825 packets_out: 0,
826 }
827 }
828
829 pub fn set_address(&mut self, addr: IpAddress, prefix: u8) {
831 self.tunnel_address = Some(addr);
832 self.tunnel_prefix = prefix;
833 }
834
835 pub fn calculate_mtu(outer_mtu: u16, is_ipv6: bool) -> u16 {
842 let overhead = if is_ipv6 { 80u16 } else { 60u16 };
843 outer_mtu.saturating_sub(overhead)
844 }
845
846 fn peer_key_hash(public_key: &[u8; 32]) -> u64 {
848 let hash = blake2s(public_key, 32);
849 u64::from_le_bytes([
850 hash[0], hash[1], hash[2], hash[3], hash[4], hash[5], hash[6], hash[7],
851 ])
852 }
853
854 pub fn add_peer(&mut self, peer: WireGuardPeer) -> Result<(), WireGuardError> {
856 let key = Self::peer_key_hash(&peer.public_key);
857 self.peers.insert(key, peer);
858 Ok(())
859 }
860
861 pub fn remove_peer(&mut self, public_key: &[u8; 32]) -> Result<(), WireGuardError> {
863 let key = Self::peer_key_hash(public_key);
864 self.peers
865 .remove(&key)
866 .map(|_| ())
867 .ok_or(WireGuardError::PeerNotFound)
868 }
869
870 pub fn get_peer(&self, public_key: &[u8; 32]) -> Option<&WireGuardPeer> {
872 let key = Self::peer_key_hash(public_key);
873 self.peers.get(&key)
874 }
875
876 pub fn get_peer_mut(&mut self, public_key: &[u8; 32]) -> Option<&mut WireGuardPeer> {
878 let key = Self::peer_key_hash(public_key);
879 self.peers.get_mut(&key)
880 }
881
882 pub fn find_peer_for_ip(&self, dst: &Ipv4Address) -> Option<&WireGuardPeer> {
884 self.peers.values().find(|peer| peer.is_allowed(dst))
885 }
886
887 pub fn up(&mut self) -> Result<(), WireGuardError> {
889 if self.tunnel_address.is_none() {
890 return Err(WireGuardError::NotConfigured);
891 }
892 self.is_up = true;
893 Ok(())
894 }
895
896 pub fn down(&mut self) {
898 self.is_up = false;
899 }
900
901 pub fn alloc_sender_index(&mut self) -> u32 {
903 let idx = self.next_sender_index;
904 self.next_sender_index = self.next_sender_index.wrapping_add(1);
905 if self.next_sender_index == 0 {
906 self.next_sender_index = 1;
907 }
908 idx
909 }
910
911 pub fn peer_count(&self) -> usize {
913 self.peers.len()
914 }
915}
916
/// Action requested by `check_peer_timers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimerEvent {
    /// The session needs a rekey and none is pending.
    RekeyInitiate,
    /// The armed handshake retry time has been reached.
    HandshakeRetry,
    /// The peer has been idle for at least `SESSION_EXPIRY_SECONDS`.
    SessionExpiry,
    /// A pending handshake has used up `MAX_HANDSHAKE_RETRIES`.
    DeadPeer,
    /// A keepalive is due.
    Keepalive,
}
933
/// Timer bookkeeping kept alongside a peer.
///
/// The derived `Default` yields all-zero timestamps and no pending rekey,
/// which is exactly what [`PeerTimers::new`] returns.
#[derive(Default)]
pub struct PeerTimers {
    /// Millisecond timestamp at which the current handshake was started.
    pub handshake_initiated_ms: u64,
    /// Timestamp of the last keepalive sent.
    pub last_keepalive_sent: u64,
    /// Set while a rekey is in progress; suppresses further `RekeyInitiate` events.
    pub rekey_pending: bool,
}

impl PeerTimers {
    /// Creates timers with zero timestamps and no pending rekey.
    pub fn new() -> Self {
        Self::default()
    }
}
959
960pub fn check_peer_timers(
962 peer: &WireGuardPeer,
963 timers: &PeerTimers,
964 now_secs: u64,
965 now_ms: u64,
966) -> Option<TimerEvent> {
967 if peer.is_session_expired(now_secs) {
969 return Some(TimerEvent::SessionExpiry);
970 }
971
972 if let Some(ref session) = peer.session {
974 if session.needs_rekey(now_secs) && !timers.rekey_pending {
975 return Some(TimerEvent::RekeyInitiate);
976 }
977 }
978
979 if peer.handshake_state == HandshakeState::InitSent {
981 if peer.handshake_retries >= MAX_HANDSHAKE_RETRIES {
982 return Some(TimerEvent::DeadPeer);
983 }
984 if now_ms >= peer.next_retry_ms && peer.next_retry_ms > 0 {
985 return Some(TimerEvent::HandshakeRetry);
986 }
987 }
988
989 if peer.needs_keepalive(now_secs) {
991 return Some(TimerEvent::Keepalive);
992 }
993
994 None
995}
996
/// Unit tests for the hashing wrappers, replay window, peer management,
/// session keys, MTU calculation, allowed IPs, timers and transport packets.
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    // --- Hash wrappers ---

    /// First four bytes of the hash of the empty input.
    #[test]
    fn test_blake2s_empty_input() {
        let hash = blake2s(b"", 32);
        assert_eq!(hash[0], 0x69);
        assert_eq!(hash[1], 0x21);
        assert_eq!(hash[2], 0x7a);
        assert_eq!(hash[3], 0x30);
    }

    /// First two bytes of the hash of "abc".
    #[test]
    fn test_blake2s_abc() {
        let hash = blake2s(b"abc", 32);
        assert_eq!(hash[0], 0x50);
        assert_eq!(hash[1], 0x8C);
        assert!(hash.iter().any(|&b| b != 0));
    }

    #[test]
    fn test_blake2s_deterministic() {
        let h1 = blake2s(b"test data", 32);
        let h2 = blake2s(b"test data", 32);
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_blake2s_different_inputs() {
        let h1 = blake2s(b"hello", 32);
        let h2 = blake2s(b"world", 32);
        assert_ne!(h1, h2);
    }

    /// Keyed hashing differs from unkeyed hashing and is deterministic.
    #[test]
    fn test_blake2s_keyed_mode() {
        let key = [0x42u8; 32];
        let h1 = blake2s_keyed(&key, b"data", 32);
        let h2 = blake2s(b"data", 32);
        assert_ne!(h1, h2);
        let h3 = blake2s_keyed(&key, b"data", 32);
        assert_eq!(h1, h3);
    }

    #[test]
    fn test_blake2s_keyed_different_keys() {
        let key1 = [0x01u8; 32];
        let key2 = [0x02u8; 32];
        let h1 = blake2s_keyed(&key1, b"data", 32);
        let h2 = blake2s_keyed(&key2, b"data", 32);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_hmac_blake2s() {
        let key = [0xABu8; 32];
        let mac1 = hmac_blake2s(&key, b"message");
        let mac2 = hmac_blake2s(&key, b"message");
        assert_eq!(mac1, mac2);

        let mac3 = hmac_blake2s(&key, b"different message");
        assert_ne!(mac1, mac3);
    }

    // --- Anti-replay window ---

    #[test]
    fn test_replay_window_accept_new() {
        let mut window = AntiReplayWindow::new();
        assert!(window.check(0));
        window.update(0);
        assert!(window.check(1));
        window.update(1);
        assert!(window.check(2));
        window.update(2);
        assert!(window.check(100));
    }

    #[test]
    fn test_replay_window_reject_duplicate() {
        let mut window = AntiReplayWindow::new();
        window.update(5);
        assert!(!window.check(5));
    }

    /// Counters further behind than the window size are rejected.
    #[test]
    fn test_replay_window_reject_old() {
        let mut window = AntiReplayWindow::new();
        window.update(3000);
        assert!(!window.check(0));
    }

    /// An unseen counter inside the window is accepted exactly once.
    #[test]
    fn test_replay_window_accept_within_window() {
        let mut window = AntiReplayWindow::new();
        window.update(100);
        assert!(window.check(99));
        window.update(99);
        assert!(!window.check(99));
    }

    /// A jump larger than the window discards all older state.
    #[test]
    fn test_replay_window_large_jump() {
        let mut window = AntiReplayWindow::new();
        window.update(0);
        window.update(10000);
        assert!(!window.check(0));
        assert!(!window.check(100));
        assert!(window.check(10001));
    }

    // --- Peer management ---

    #[test]
    fn test_peer_add_remove() {
        let seed = [1u8; 32];
        let key = X25519KeyPair::from_seed(&seed);
        let mut iface = WireGuardInterface::new(b"wg0", key, DEFAULT_PORT);

        let peer_pub = [0x42u8; 32];
        let peer = WireGuardPeer::new(peer_pub);
        assert!(iface.add_peer(peer).is_ok());
        assert_eq!(iface.peer_count(), 1);

        assert!(iface.get_peer(&peer_pub).is_some());
        assert!(iface.remove_peer(&peer_pub).is_ok());
        assert_eq!(iface.peer_count(), 0);
        assert!(iface.get_peer(&peer_pub).is_none());
    }

    #[test]
    fn test_peer_remove_not_found() {
        let seed = [1u8; 32];
        let key = X25519KeyPair::from_seed(&seed);
        let mut iface = WireGuardInterface::new(b"wg0", key, DEFAULT_PORT);

        let fake_pub = [0xFFu8; 32];
        assert_eq!(
            iface.remove_peer(&fake_pub),
            Err(WireGuardError::PeerNotFound)
        );
    }

    #[test]
    fn test_peer_lookup() {
        let seed = [1u8; 32];
        let key = X25519KeyPair::from_seed(&seed);
        let mut iface = WireGuardInterface::new(b"wg0", key, DEFAULT_PORT);

        let pub1 = [0x01u8; 32];
        let pub2 = [0x02u8; 32];
        iface.add_peer(WireGuardPeer::new(pub1)).unwrap();
        iface.add_peer(WireGuardPeer::new(pub2)).unwrap();

        assert!(iface.get_peer(&pub1).is_some());
        assert!(iface.get_peer(&pub2).is_some());
        assert_eq!(iface.peer_count(), 2);
    }

    // --- Key rotation ---

    /// A rekey becomes due exactly at `REKEY_AFTER_SECONDS`.
    #[test]
    fn test_key_rotation_by_time() {
        let keys = SessionKeys::new([1u8; 32], [2u8; 32], 0);
        assert!(!keys.needs_rekey(0));
        assert!(!keys.needs_rekey(119));
        assert!(keys.needs_rekey(120));
        assert!(keys.needs_rekey(200));
    }

    #[test]
    fn test_key_rotation_by_messages() {
        let keys = SessionKeys::new([1u8; 32], [2u8; 32], 0);
        keys.messages_sent
            .store(REKEY_AFTER_MESSAGES, Ordering::Relaxed);
        assert!(keys.needs_rekey(0));
    }

    // --- MTU ---

    #[test]
    fn test_mtu_calculation_ipv4() {
        assert_eq!(WireGuardInterface::calculate_mtu(1500, false), 1440);
    }

    #[test]
    fn test_mtu_calculation_ipv6() {
        assert_eq!(WireGuardInterface::calculate_mtu(1500, true), 1420);
    }

    /// The subtraction saturates at zero instead of wrapping.
    #[test]
    fn test_mtu_calculation_small() {
        assert_eq!(WireGuardInterface::calculate_mtu(50, false), 0);
        assert_eq!(WireGuardInterface::calculate_mtu(60, false), 0);
        assert_eq!(WireGuardInterface::calculate_mtu(61, false), 1);
    }

    // --- Nonce counter ---

    #[test]
    fn test_nonce_counter_increment() {
        let keys = SessionKeys::new([1u8; 32], [2u8; 32], 0);
        assert_eq!(keys.next_nonce().unwrap(), 0);
        assert_eq!(keys.next_nonce().unwrap(), 1);
        assert_eq!(keys.next_nonce().unwrap(), 2);
    }

    #[test]
    fn test_nonce_counter_overflow() {
        let keys = SessionKeys::new([1u8; 32], [2u8; 32], 0);
        keys.sending_nonce
            .store(REKEY_AFTER_MESSAGES, Ordering::Relaxed);
        assert_eq!(keys.next_nonce(), Err(WireGuardError::NonceOverflow));
    }

    // --- Allowed IPs ---

    #[test]
    fn test_allowed_ip_exact_match() {
        let aip = AllowedIp::new(Ipv4Address::new(10, 0, 0, 1), 32);
        assert!(aip.matches(&Ipv4Address::new(10, 0, 0, 1)));
        assert!(!aip.matches(&Ipv4Address::new(10, 0, 0, 2)));
    }

    #[test]
    fn test_allowed_ip_subnet_match() {
        let aip = AllowedIp::new(Ipv4Address::new(10, 0, 0, 0), 24);
        assert!(aip.matches(&Ipv4Address::new(10, 0, 0, 1)));
        assert!(aip.matches(&Ipv4Address::new(10, 0, 0, 254)));
        assert!(!aip.matches(&Ipv4Address::new(10, 0, 1, 1)));
    }

    #[test]
    fn test_allowed_ip_wildcard() {
        let aip = AllowedIp::new(Ipv4Address::new(0, 0, 0, 0), 0);
        assert!(aip.matches(&Ipv4Address::new(192, 168, 1, 1)));
        assert!(aip.matches(&Ipv4Address::new(10, 0, 0, 1)));
    }

    // --- Session state ---

    #[test]
    fn test_session_state_transitions() {
        let mut peer = WireGuardPeer::new([0x01u8; 32]);
        assert_eq!(peer.handshake_state, HandshakeState::None);

        peer.handshake_state = HandshakeState::InitSent;
        assert_eq!(peer.handshake_state, HandshakeState::InitSent);

        peer.handshake_state = HandshakeState::Established;
        assert_eq!(peer.handshake_state, HandshakeState::Established);
    }

    /// 100 idle seconds is not expired; 181 idle seconds is.
    #[test]
    fn test_session_expiry() {
        let mut peer = WireGuardPeer::new([0x01u8; 32]);
        peer.last_received = 100;
        peer.last_sent = 100;

        assert!(!peer.is_session_expired(200));
        assert!(peer.is_session_expired(281));
    }

    /// The delay doubles per retry and drops to 0 at the retry limit.
    #[test]
    fn test_handshake_retry_backoff() {
        let mut peer = WireGuardPeer::new([0x01u8; 32]);
        peer.handshake_retries = 0;
        assert_eq!(peer.retry_delay_ms(), 1000);

        peer.handshake_retries = 1;
        assert_eq!(peer.retry_delay_ms(), 2000);

        peer.handshake_retries = 2;
        assert_eq!(peer.retry_delay_ms(), 4000);

        peer.handshake_retries = 3;
        assert_eq!(peer.retry_delay_ms(), 8000);

        peer.handshake_retries = 4;
        assert_eq!(peer.retry_delay_ms(), 16000);

        peer.handshake_retries = MAX_HANDSHAKE_RETRIES;
        assert_eq!(peer.retry_delay_ms(), 0);
    }

    // --- Handshake messages ---

    #[test]
    fn test_handshake_initiation_size() {
        let static_key = X25519KeyPair::from_seed(&[0x11u8; 32]);
        let remote_pub = [0x22u8; 32];
        let psk = [0x33u8; 32];
        let timestamp = [0u8; 12];

        let mut ctx = HandshakeContext::new();
        let msg = ctx.create_initiation(&static_key, &remote_pub, &psk, &timestamp, 1);

        assert_eq!(msg.len(), HANDSHAKE_INIT_SIZE);
        assert_eq!(msg[0], MSG_HANDSHAKE_INIT);
        assert_eq!(u32::from_le_bytes([msg[4], msg[5], msg[6], msg[7]]), 1);
    }

    #[test]
    fn test_handshake_response_size() {
        let static_key = X25519KeyPair::from_seed(&[0x44u8; 32]);
        let remote_pub = [0x55u8; 32];
        let psk = [0x66u8; 32];

        let mut ctx = HandshakeContext::new();
        let msg = ctx.create_response(&static_key, &remote_pub, &psk, 2, 1);

        assert_eq!(msg.len(), HANDSHAKE_RESP_SIZE);
        assert_eq!(msg[0], MSG_HANDSHAKE_RESP);
        assert_eq!(u32::from_le_bytes([msg[4], msg[5], msg[6], msg[7]]), 2);
        assert_eq!(u32::from_le_bytes([msg[8], msg[9], msg[10], msg[11]]), 1);
    }

    // --- Transport ---

    /// Round trip: the sender's sending key equals the receiver's receiving
    /// key, and the decrypted output starts with the original payload (it may
    /// be longer because of zero padding).
    #[test]
    fn test_transport_encrypt_decrypt() {
        let send_key = [0xAAu8; 32];
        let recv_key = [0xAAu8; 32];
        let send_session = SessionKeys::new(send_key, [0u8; 32], 0);
        let recv_session = SessionKeys::new([0u8; 32], recv_key, 0);

        let payload = b"hello wireguard";
        let encrypted = encrypt_transport(&send_session, 42, payload).unwrap();

        assert_eq!(encrypted[0], MSG_TRANSPORT_DATA);
        assert_eq!(
            u32::from_le_bytes([encrypted[4], encrypted[5], encrypted[6], encrypted[7]]),
            42
        );

        let mut window = AntiReplayWindow::new();
        let decrypted = decrypt_transport(&recv_session, &mut window, &encrypted).unwrap();
        assert!(decrypted.len() >= payload.len());
        assert_eq!(&decrypted[..payload.len()], payload);
    }

    // --- Timers ---

    /// A keepalive becomes due exactly `keepalive_interval` after the last send.
    #[test]
    fn test_keepalive_timing() {
        let mut peer = WireGuardPeer::new([0x01u8; 32]);
        peer.handshake_state = HandshakeState::Established;
        peer.keepalive_interval = 25;
        peer.last_sent = 100;

        assert!(!peer.needs_keepalive(120));
        assert!(peer.needs_keepalive(125));
        assert!(peer.needs_keepalive(200));
    }

    #[test]
    fn test_keepalive_disabled() {
        let mut peer = WireGuardPeer::new([0x01u8; 32]);
        peer.handshake_state = HandshakeState::Established;
        peer.keepalive_interval = 0;
        peer.last_sent = 0;

        assert!(!peer.needs_keepalive(1000));
    }

    #[test]
    fn test_timer_event_session_expiry() {
        let mut peer = WireGuardPeer::new([0x01u8; 32]);
        peer.last_received = 100;
        peer.last_sent = 100;
        let timers = PeerTimers::new();

        let event = check_peer_timers(&peer, &timers, 300, 300_000);
        assert_eq!(event, Some(TimerEvent::SessionExpiry));
    }

    #[test]
    fn test_timer_event_dead_peer() {
        let mut peer = WireGuardPeer::new([0x01u8; 32]);
        peer.handshake_state = HandshakeState::InitSent;
        peer.handshake_retries = MAX_HANDSHAKE_RETRIES;
        let timers = PeerTimers::new();

        let event = check_peer_timers(&peer, &timers, 0, 0);
        assert_eq!(event, Some(TimerEvent::DeadPeer));
    }

    // --- Interface lifecycle ---

    /// Bringing the interface up requires a tunnel address.
    #[test]
    fn test_interface_up_down() {
        let seed = [1u8; 32];
        let key = X25519KeyPair::from_seed(&seed);
        let mut iface = WireGuardInterface::new(b"wg0", key, DEFAULT_PORT);

        assert_eq!(iface.up(), Err(WireGuardError::NotConfigured));

        iface.set_address(IpAddress::V4(Ipv4Address::new(10, 0, 0, 1)), 24);
        assert!(iface.up().is_ok());
        assert!(iface.is_up);

        iface.down();
        assert!(!iface.is_up);
    }
}