⚠️ VeridianOS Kernel Documentation — This is low-level `no_std` kernel code. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/net/firewall/
nat.rs

1//! Network Address Translation (NAT) engine
2//!
3//! Supports SNAT (source NAT), DNAT (destination NAT), and IP masquerading.
4//! Uses a port pool with bitmap allocation for ephemeral port assignment.
5//! Implements RFC 1624 incremental checksum updates to avoid full
6//! recalculation when translating addresses.
7
8#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::collections::BTreeMap;
12
13use super::conntrack::{ConntrackKey, NatInfo};
14use crate::{
15    error::KernelError,
16    net::{Ipv4Address, Port},
17    sync::once_lock::GlobalState,
18};
19
20// ============================================================================
21// Constants
22// ============================================================================
23
/// First port (inclusive) of the ephemeral range handed out by NAT.
const PORT_POOL_START: u16 = 49152;

/// Last port (inclusive) of the ephemeral range handed out by NAT.
const PORT_POOL_END: u16 = 65535;

/// Number of ports in `PORT_POOL_START..=PORT_POOL_END` (16384).
const PORT_POOL_SIZE: usize = PORT_POOL_END as usize - PORT_POOL_START as usize + 1;

/// Number of 64-bit words needed to hold one bit per pool port.
const BITMAP_WORDS: usize = PORT_POOL_SIZE.div_ceil(u64::BITS as usize);
35
36// ============================================================================
37// NAT Type
38// ============================================================================
39
40/// Type of NAT translation
41#[derive(Debug, Clone, Copy, PartialEq, Eq)]
42pub enum NatType {
43    /// Source NAT: rewrite source address/port
44    Snat,
45    /// Destination NAT: rewrite destination address/port
46    Dnat,
47    /// Masquerade: SNAT using the outgoing interface address
48    Masquerade,
49}
50
51// ============================================================================
52// NAT Mapping
53// ============================================================================
54
55/// A single NAT mapping recording original and translated addresses
56#[derive(Debug, Clone, Copy, PartialEq, Eq)]
57pub struct NatMapping {
58    /// NAT type
59    pub nat_type: NatType,
60    /// Original source address
61    pub original_src_ip: Ipv4Address,
62    /// Original source port
63    pub original_src_port: Port,
64    /// Original destination address
65    pub original_dst_ip: Ipv4Address,
66    /// Original destination port
67    pub original_dst_port: Port,
68    /// Translated source address
69    pub translated_src_ip: Ipv4Address,
70    /// Translated source port
71    pub translated_src_port: Port,
72    /// Translated destination address
73    pub translated_dst_ip: Ipv4Address,
74    /// Translated destination port
75    pub translated_dst_port: Port,
76}
77
78impl NatMapping {
79    /// Convert to a NatInfo for conntrack
80    pub fn to_nat_info(&self) -> NatInfo {
81        NatInfo {
82            original_src_ip: self.original_src_ip,
83            original_src_port: self.original_src_port,
84            translated_src_ip: self.translated_src_ip,
85            translated_src_port: self.translated_src_port,
86            original_dst_ip: self.original_dst_ip,
87            original_dst_port: self.original_dst_port,
88            translated_dst_ip: self.translated_dst_ip,
89            translated_dst_port: self.translated_dst_port,
90        }
91    }
92}
93
94// ============================================================================
95// Port Pool
96// ============================================================================
97
98/// Bitmap-based ephemeral port allocator for NAT
99///
100/// Manages ports in the range 49152-65535 using a compact bitmap.
101/// Each bit represents one port: 0 = free, 1 = allocated.
102pub struct PortPool {
103    /// Bitmap of allocated ports (bit N = port PORT_POOL_START + N)
104    bitmap: [u64; BITMAP_WORDS],
105    /// Number of currently allocated ports
106    allocated_count: u16,
107}
108
109impl PortPool {
110    /// Create a new port pool with all ports available
111    pub fn new() -> Self {
112        Self {
113            bitmap: [0u64; BITMAP_WORDS],
114            allocated_count: 0,
115        }
116    }
117
118    /// Number of ports currently allocated
119    pub fn allocated(&self) -> u16 {
120        self.allocated_count
121    }
122
123    /// Number of ports available
124    pub fn available(&self) -> u16 {
125        PORT_POOL_SIZE as u16 - self.allocated_count
126    }
127
128    /// Allocate the next available port
129    ///
130    /// Returns the allocated port number or None if pool is exhausted.
131    pub fn allocate(&mut self) -> Option<Port> {
132        for (word_idx, word) in self.bitmap.iter_mut().enumerate() {
133            if *word == u64::MAX {
134                continue; // All bits set in this word
135            }
136            // Find first zero bit
137            let bit_idx = (!*word).trailing_zeros() as usize;
138            let port_offset = word_idx * 64 + bit_idx;
139            if port_offset >= PORT_POOL_SIZE {
140                return None;
141            }
142            *word |= 1u64 << bit_idx;
143            self.allocated_count += 1;
144            return Some(PORT_POOL_START + port_offset as u16);
145        }
146        None
147    }
148
149    /// Release a previously allocated port
150    pub fn release(&mut self, port: Port) -> bool {
151        if !(PORT_POOL_START..=PORT_POOL_END).contains(&port) {
152            return false;
153        }
154        let offset = (port - PORT_POOL_START) as usize;
155        let word_idx = offset / 64;
156        let bit_idx = offset % 64;
157        if self.bitmap[word_idx] & (1u64 << bit_idx) != 0 {
158            self.bitmap[word_idx] &= !(1u64 << bit_idx);
159            self.allocated_count -= 1;
160            true
161        } else {
162            false // Port was not allocated
163        }
164    }
165
166    /// Check if a specific port is allocated
167    pub fn is_allocated(&self, port: Port) -> bool {
168        if !(PORT_POOL_START..=PORT_POOL_END).contains(&port) {
169            return false;
170        }
171        let offset = (port - PORT_POOL_START) as usize;
172        let word_idx = offset / 64;
173        let bit_idx = offset % 64;
174        self.bitmap[word_idx] & (1u64 << bit_idx) != 0
175    }
176}
177
178impl Default for PortPool {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184// ============================================================================
185// Incremental Checksum Update (RFC 1624)
186// ============================================================================
187
/// Incrementally update a 16-bit one's-complement checksum after one 16-bit
/// word of the covered data changes from `old_value` to `new_value`.
///
/// Applies eqn. 3 of RFC 1624:
///   HC' = ~(~HC + ~m + m')
///
/// where `HC` is the old checksum, `m` the old word and `m'` the new word.
/// The additions are one's-complement (end-around carry), so the covered data
/// never has to be summed again.
pub fn update_checksum(old_checksum: u16, old_value: u16, new_value: u16) -> u16 {
    // Feed each carry out of bit 15 back in after every addition, keeping the
    // running one's-complement sum within 16 bits.
    let sum = [!old_checksum, !old_value, new_value]
        .iter()
        .fold(0u32, |acc, &word| {
            let wide = acc + u32::from(word);
            (wide & 0xFFFF) + (wide >> 16)
        });
    !(sum as u16)
}

/// Incrementally update a checksum after a 32-bit field (such as an IPv4
/// address) changes, applying [`update_checksum`] to the high 16-bit half and
/// then to the low half.
pub fn update_checksum_32(old_checksum: u16, old_addr: u32, new_addr: u32) -> u16 {
    let high = |v: u32| (v >> 16) as u16;
    let low = |v: u32| (v & 0xFFFF) as u16;

    let after_high = update_checksum(old_checksum, high(old_addr), high(new_addr));
    update_checksum(after_high, low(old_addr), low(new_addr))
}
223
224// ============================================================================
225// NAT Engine
226// ============================================================================
227
228/// The NAT engine managing translations and port allocation
229pub struct NatEngine {
230    /// Ephemeral port pool
231    pub port_pool: PortPool,
232    /// Active NAT mappings indexed by original connection key
233    pub mappings: BTreeMap<ConntrackKey, NatMapping>,
234    /// Masquerade address (outgoing interface IP)
235    pub masquerade_addr: Ipv4Address,
236    /// Total translations performed
237    pub total_translations: u64,
238}
239
240impl NatEngine {
241    /// Create a new NAT engine
242    pub fn new() -> Self {
243        Self {
244            port_pool: PortPool::new(),
245            mappings: BTreeMap::new(),
246            masquerade_addr: Ipv4Address::ANY,
247            total_translations: 0,
248        }
249    }
250
251    /// Set the masquerade (outgoing interface) address
252    pub fn set_masquerade_addr(&mut self, addr: Ipv4Address) {
253        self.masquerade_addr = addr;
254    }
255
256    /// Translate an outbound packet with SNAT
257    ///
258    /// Rewrites the source address and allocates a new source port.
259    /// Returns the NAT mapping on success.
260    pub fn translate_outbound_snat(
261        &mut self,
262        key: &ConntrackKey,
263        new_src_ip: Ipv4Address,
264    ) -> Option<NatMapping> {
265        // Check for existing mapping
266        if let Some(mapping) = self.mappings.get(key) {
267            return Some(*mapping);
268        }
269
270        // Allocate a new port
271        let new_port = self.port_pool.allocate()?;
272
273        let mapping = NatMapping {
274            nat_type: NatType::Snat,
275            original_src_ip: key.src_ip,
276            original_src_port: key.src_port,
277            original_dst_ip: key.dst_ip,
278            original_dst_port: key.dst_port,
279            translated_src_ip: new_src_ip,
280            translated_src_port: new_port,
281            translated_dst_ip: key.dst_ip,
282            translated_dst_port: key.dst_port,
283        };
284
285        self.mappings.insert(*key, mapping);
286        self.total_translations += 1;
287        Some(mapping)
288    }
289
290    /// Translate an outbound packet with masquerading
291    ///
292    /// Uses the configured masquerade address as the source.
293    pub fn translate_outbound_masquerade(&mut self, key: &ConntrackKey) -> Option<NatMapping> {
294        let addr = self.masquerade_addr;
295        if addr == Ipv4Address::ANY {
296            return None;
297        }
298
299        // Check for existing mapping
300        if let Some(mapping) = self.mappings.get(key) {
301            return Some(*mapping);
302        }
303
304        let new_port = self.port_pool.allocate()?;
305
306        let mapping = NatMapping {
307            nat_type: NatType::Masquerade,
308            original_src_ip: key.src_ip,
309            original_src_port: key.src_port,
310            original_dst_ip: key.dst_ip,
311            original_dst_port: key.dst_port,
312            translated_src_ip: addr,
313            translated_src_port: new_port,
314            translated_dst_ip: key.dst_ip,
315            translated_dst_port: key.dst_port,
316        };
317
318        self.mappings.insert(*key, mapping);
319        self.total_translations += 1;
320        Some(mapping)
321    }
322
323    /// Translate an inbound packet with DNAT
324    ///
325    /// Rewrites the destination address and port.
326    pub fn translate_inbound_dnat(
327        &mut self,
328        key: &ConntrackKey,
329        new_dst_ip: Ipv4Address,
330        new_dst_port: Port,
331    ) -> Option<NatMapping> {
332        // Check for existing mapping
333        if let Some(mapping) = self.mappings.get(key) {
334            return Some(*mapping);
335        }
336
337        let mapping = NatMapping {
338            nat_type: NatType::Dnat,
339            original_src_ip: key.src_ip,
340            original_src_port: key.src_port,
341            original_dst_ip: key.dst_ip,
342            original_dst_port: key.dst_port,
343            translated_src_ip: key.src_ip,
344            translated_src_port: key.src_port,
345            translated_dst_ip: new_dst_ip,
346            translated_dst_port: new_dst_port,
347        };
348
349        self.mappings.insert(*key, mapping);
350        self.total_translations += 1;
351        Some(mapping)
352    }
353
354    /// Look up a reverse NAT mapping for inbound reply traffic
355    ///
356    /// Given a reply packet's key, find the corresponding SNAT/masquerade
357    /// mapping to reverse the translation.
358    pub fn lookup_reverse(&self, reply_key: &ConntrackKey) -> Option<&NatMapping> {
359        // For SNAT/masquerade, the reply destination is our translated source.
360        // We need to find the original mapping where:
361        //   translated_src_ip == reply_key.dst_ip
362        //   translated_src_port == reply_key.dst_port
363        for mapping in self.mappings.values() {
364            match mapping.nat_type {
365                NatType::Snat | NatType::Masquerade => {
366                    if mapping.translated_src_ip == reply_key.dst_ip
367                        && mapping.translated_src_port == reply_key.dst_port
368                        && mapping.original_dst_ip == reply_key.src_ip
369                    {
370                        return Some(mapping);
371                    }
372                }
373                NatType::Dnat => {
374                    if mapping.translated_dst_ip == reply_key.src_ip
375                        && mapping.translated_dst_port == reply_key.src_port
376                        && mapping.original_src_ip == reply_key.dst_ip
377                    {
378                        return Some(mapping);
379                    }
380                }
381            }
382        }
383        None
384    }
385
386    /// Remove a NAT mapping and release its allocated port
387    pub fn remove_mapping(&mut self, key: &ConntrackKey) -> Option<NatMapping> {
388        if let Some(mapping) = self.mappings.remove(key) {
389            // Release port for SNAT/masquerade
390            match mapping.nat_type {
391                NatType::Snat | NatType::Masquerade => {
392                    self.port_pool.release(mapping.translated_src_port);
393                }
394                NatType::Dnat => {}
395            }
396            Some(mapping)
397        } else {
398            None
399        }
400    }
401
402    /// Number of active mappings
403    pub fn mapping_count(&self) -> usize {
404        self.mappings.len()
405    }
406}
407
408impl Default for NatEngine {
409    fn default() -> Self {
410        Self::new()
411    }
412}
413
414// ============================================================================
415// Global State
416// ============================================================================
417
418static NAT_ENGINE: GlobalState<spin::Mutex<NatEngine>> = GlobalState::new();
419
420/// Initialize the NAT engine
421pub fn init() -> Result<(), KernelError> {
422    NAT_ENGINE
423        .init(spin::Mutex::new(NatEngine::new()))
424        .map_err(|_| KernelError::InvalidAddress { addr: 0 })?;
425    Ok(())
426}
427
428/// Access the global NAT engine
429pub fn with_nat<R, F: FnOnce(&mut NatEngine) -> R>(f: F) -> Option<R> {
430    NAT_ENGINE.with(|lock| {
431        let mut engine = lock.lock();
432        f(&mut engine)
433    })
434}
435
436// ============================================================================
437// Tests
438// ============================================================================
439
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_port_pool_allocate() {
        let mut pool = PortPool::new();
        let port = pool.allocate().unwrap();
        assert_eq!(port, PORT_POOL_START);
        assert_eq!(pool.allocated(), 1);
        assert!(pool.is_allocated(port));
    }

    #[test]
    fn test_port_pool_release() {
        let mut pool = PortPool::new();
        let port = pool.allocate().unwrap();
        assert!(pool.release(port));
        assert_eq!(pool.allocated(), 0);
        assert!(!pool.is_allocated(port));
    }

    #[test]
    fn test_port_pool_release_invalid() {
        let mut pool = PortPool::new();
        assert!(!pool.release(80)); // Below range
        assert!(!pool.release(PORT_POOL_START)); // Not allocated
    }

    #[test]
    fn test_port_pool_sequential_allocation() {
        let mut pool = PortPool::new();
        let p1 = pool.allocate().unwrap();
        let p2 = pool.allocate().unwrap();
        let p3 = pool.allocate().unwrap();
        assert_eq!(p1, PORT_POOL_START);
        assert_eq!(p2, PORT_POOL_START + 1);
        assert_eq!(p3, PORT_POOL_START + 2);
        assert_eq!(pool.allocated(), 3);
    }

    #[test]
    fn test_port_pool_reuse_released() {
        let mut pool = PortPool::new();
        let p1 = pool.allocate().unwrap();
        let _p2 = pool.allocate().unwrap();
        pool.release(p1);
        let p3 = pool.allocate().unwrap();
        assert_eq!(p3, p1); // Should reuse first available
    }

    #[test]
    fn test_port_pool_exhaustion() {
        let mut pool = PortPool::new();
        for _ in 0..PORT_POOL_SIZE {
            assert!(pool.allocate().is_some());
        }
        assert_eq!(pool.available(), 0);
        assert!(pool.allocate().is_none());
        assert!(pool.is_allocated(PORT_POOL_END));
    }

    #[test]
    fn test_checksum_update_identity() {
        // If old == new, checksum should not change
        let checksum = 0x1234;
        let result = update_checksum(checksum, 0xABCD, 0xABCD);
        assert_eq!(result, checksum);
    }

    #[test]
    fn test_checksum_update_basic() {
        // Worked example from RFC 1624: under checksum HC = 0xDD2F a field
        // changes from m = 0x5555 to m' = 0x3285. Recomputing from scratch
        // gives 0x0000, and eqn. 3 must agree.
        let result = update_checksum(0xDD2F, 0x5555, 0x3285);
        assert_eq!(result, 0x0000);
    }

    #[test]
    fn test_checksum_update_matches_recompute() {
        // Words [0x1234, 0x5678] have checksum !(0x1234 + 0x5678) = 0x9753.
        // Changing 0x5678 to 0x0001 gives !(0x1234 + 0x0001) = 0xEDCA.
        assert_eq!(update_checksum(0x9753, 0x5678, 0x0001), 0xEDCA);
    }

    #[test]
    fn test_checksum_update_32_identity() {
        let checksum = 0xABCD;
        let addr: u32 = 0xC0A80101; // 192.168.1.1
        let result = update_checksum_32(checksum, addr, addr);
        assert_eq!(result, checksum);
    }

    #[test]
    fn test_checksum_update_32_matches_recompute() {
        // Checksum over 192.168.1.1 (0xC0A8, 0x0101) is !0xC1A9 = 0x3E56;
        // over 203.0.113.1 (0xCB00, 0x7101) it is !0x3C02 = 0xC3FD.
        let old_addr: u32 = 0xC0A8_0101;
        let new_addr: u32 = 0xCB00_7101;
        assert_eq!(update_checksum_32(0x3E56, old_addr, new_addr), 0xC3FD);
    }

    #[test]
    fn test_nat_engine_snat() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(8, 8, 8, 8),
            12345,
            53,
            ConntrackKey::PROTO_UDP,
        );
        let public_ip = Ipv4Address::new(203, 0, 113, 1);

        let mapping = engine.translate_outbound_snat(&key, public_ip).unwrap();
        assert_eq!(mapping.nat_type, NatType::Snat);
        assert_eq!(mapping.original_src_ip, Ipv4Address::new(192, 168, 1, 100));
        assert_eq!(mapping.translated_src_ip, public_ip);
        assert!(mapping.translated_src_port >= PORT_POOL_START);
        assert_eq!(engine.mapping_count(), 1);
    }

    #[test]
    fn test_nat_engine_snat_reuses_existing_mapping() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(8, 8, 8, 8),
            12345,
            53,
            ConntrackKey::PROTO_UDP,
        );
        let public_ip = Ipv4Address::new(203, 0, 113, 1);

        let first = engine.translate_outbound_snat(&key, public_ip).unwrap();
        let second = engine.translate_outbound_snat(&key, public_ip).unwrap();
        assert_eq!(first, second);
        assert_eq!(engine.port_pool.allocated(), 1);
        assert_eq!(engine.total_translations, 1);
    }

    #[test]
    fn test_nat_engine_masquerade() {
        let mut engine = NatEngine::new();
        engine.set_masquerade_addr(Ipv4Address::new(203, 0, 113, 1));
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 50),
            Ipv4Address::new(1, 1, 1, 1),
            5000,
            443,
            ConntrackKey::PROTO_TCP,
        );

        let mapping = engine.translate_outbound_masquerade(&key).unwrap();
        assert_eq!(mapping.nat_type, NatType::Masquerade);
        assert_eq!(mapping.translated_src_ip, Ipv4Address::new(203, 0, 113, 1));
    }

    #[test]
    fn test_nat_engine_masquerade_no_addr() {
        let mut engine = NatEngine::new();
        // masquerade_addr is ANY (default)
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 50),
            Ipv4Address::new(1, 1, 1, 1),
            5000,
            443,
            ConntrackKey::PROTO_TCP,
        );
        assert!(engine.translate_outbound_masquerade(&key).is_none());
    }

    #[test]
    fn test_nat_engine_dnat() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(8, 8, 8, 8),
            Ipv4Address::new(203, 0, 113, 1),
            5000,
            80,
            ConntrackKey::PROTO_TCP,
        );
        let internal_ip = Ipv4Address::new(192, 168, 1, 10);

        let mapping = engine
            .translate_inbound_dnat(&key, internal_ip, 8080)
            .unwrap();
        assert_eq!(mapping.nat_type, NatType::Dnat);
        assert_eq!(mapping.translated_dst_ip, internal_ip);
        assert_eq!(mapping.translated_dst_port, 8080);
    }

    #[test]
    fn test_nat_engine_lookup_reverse_snat() {
        let mut engine = NatEngine::new();
        let remote = Ipv4Address::new(8, 8, 8, 8);
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            remote,
            12345,
            53,
            ConntrackKey::PROTO_UDP,
        );
        let public_ip = Ipv4Address::new(203, 0, 113, 1);
        let mapping = engine.translate_outbound_snat(&key, public_ip).unwrap();

        // The reply swaps the translated tuple: remote:53 -> public:allocated.
        let reply = ConntrackKey::new(
            remote,
            public_ip,
            53,
            mapping.translated_src_port,
            ConntrackKey::PROTO_UDP,
        );
        assert_eq!(engine.lookup_reverse(&reply), Some(&mapping));

        // A packet from an unrelated host must not match.
        let stranger = ConntrackKey::new(
            Ipv4Address::new(9, 9, 9, 9),
            public_ip,
            53,
            mapping.translated_src_port,
            ConntrackKey::PROTO_UDP,
        );
        assert!(engine.lookup_reverse(&stranger).is_none());
    }

    #[test]
    fn test_nat_engine_remove_mapping() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(8, 8, 8, 8),
            12345,
            53,
            ConntrackKey::PROTO_UDP,
        );
        let public_ip = Ipv4Address::new(203, 0, 113, 1);

        let mapping = engine.translate_outbound_snat(&key, public_ip).unwrap();
        let allocated_port = mapping.translated_src_port;
        assert!(engine.port_pool.is_allocated(allocated_port));

        engine.remove_mapping(&key);
        assert_eq!(engine.mapping_count(), 0);
        assert!(!engine.port_pool.is_allocated(allocated_port));
    }
}