1#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::collections::BTreeMap;
12
13use super::conntrack::{ConntrackKey, NatInfo};
14use crate::{
15 error::KernelError,
16 net::{Ipv4Address, Port},
17 sync::once_lock::GlobalState,
18};
19
/// First port (inclusive) handed out by [`PortPool`].
///
/// Together with [`PORT_POOL_END`] this spans 49152..=65535, which coincides
/// with the IANA dynamic/private port range (RFC 6335).
const PORT_POOL_START: u16 = 49152;

/// Last port (inclusive) handed out by [`PortPool`].
const PORT_POOL_END: u16 = 65535;

/// Number of ports in the pool.
///
/// Both bounds are widened to `usize` before subtracting so the `+ 1` can never
/// overflow `u16`, even if the range is later widened to all 65536 ports.
const PORT_POOL_SIZE: usize = PORT_POOL_END as usize - PORT_POOL_START as usize + 1;

/// Number of `u64` words needed to hold one bit per pool port.
const BITMAP_WORDS: usize = PORT_POOL_SIZE.div_ceil(u64::BITS as usize);
35
/// The kind of address translation recorded in a [`NatMapping`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NatType {
    /// Source NAT to a caller-supplied address, with a port drawn from the pool.
    Snat,
    /// Destination NAT: the destination address and port are rewritten.
    Dnat,
    /// Source NAT to the engine's configured masquerade address, with a pool port.
    Masquerade,
}
50
51#[derive(Debug, Clone, Copy, PartialEq, Eq)]
57pub struct NatMapping {
58 pub nat_type: NatType,
60 pub original_src_ip: Ipv4Address,
62 pub original_src_port: Port,
64 pub original_dst_ip: Ipv4Address,
66 pub original_dst_port: Port,
68 pub translated_src_ip: Ipv4Address,
70 pub translated_src_port: Port,
72 pub translated_dst_ip: Ipv4Address,
74 pub translated_dst_port: Port,
76}
77
78impl NatMapping {
79 pub fn to_nat_info(&self) -> NatInfo {
81 NatInfo {
82 original_src_ip: self.original_src_ip,
83 original_src_port: self.original_src_port,
84 translated_src_ip: self.translated_src_ip,
85 translated_src_port: self.translated_src_port,
86 original_dst_ip: self.original_dst_ip,
87 original_dst_port: self.original_dst_port,
88 translated_dst_ip: self.translated_dst_ip,
89 translated_dst_port: self.translated_dst_port,
90 }
91 }
92}
93
94pub struct PortPool {
103 bitmap: [u64; BITMAP_WORDS],
105 allocated_count: u16,
107}
108
109impl PortPool {
110 pub fn new() -> Self {
112 Self {
113 bitmap: [0u64; BITMAP_WORDS],
114 allocated_count: 0,
115 }
116 }
117
118 pub fn allocated(&self) -> u16 {
120 self.allocated_count
121 }
122
123 pub fn available(&self) -> u16 {
125 PORT_POOL_SIZE as u16 - self.allocated_count
126 }
127
128 pub fn allocate(&mut self) -> Option<Port> {
132 for (word_idx, word) in self.bitmap.iter_mut().enumerate() {
133 if *word == u64::MAX {
134 continue; }
136 let bit_idx = (!*word).trailing_zeros() as usize;
138 let port_offset = word_idx * 64 + bit_idx;
139 if port_offset >= PORT_POOL_SIZE {
140 return None;
141 }
142 *word |= 1u64 << bit_idx;
143 self.allocated_count += 1;
144 return Some(PORT_POOL_START + port_offset as u16);
145 }
146 None
147 }
148
149 pub fn release(&mut self, port: Port) -> bool {
151 if !(PORT_POOL_START..=PORT_POOL_END).contains(&port) {
152 return false;
153 }
154 let offset = (port - PORT_POOL_START) as usize;
155 let word_idx = offset / 64;
156 let bit_idx = offset % 64;
157 if self.bitmap[word_idx] & (1u64 << bit_idx) != 0 {
158 self.bitmap[word_idx] &= !(1u64 << bit_idx);
159 self.allocated_count -= 1;
160 true
161 } else {
162 false }
164 }
165
166 pub fn is_allocated(&self, port: Port) -> bool {
168 if !(PORT_POOL_START..=PORT_POOL_END).contains(&port) {
169 return false;
170 }
171 let offset = (port - PORT_POOL_START) as usize;
172 let word_idx = offset / 64;
173 let bit_idx = offset % 64;
174 self.bitmap[word_idx] & (1u64 << bit_idx) != 0
175 }
176}
177
178impl Default for PortPool {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
/// Incrementally updates a 16-bit one's-complement checksum after one 16-bit
/// word of the covered data changes from `old_value` to `new_value`.
///
/// Computes RFC 1624 eqn. 3, `HC' = ~(~HC + ~m + m')`, so the covered data
/// does not need to be re-summed after a rewrite.
pub fn update_checksum(old_checksum: u16, old_value: u16, new_value: u16) -> u16 {
    let mut acc = u32::from(!old_checksum) + u32::from(!old_value) + u32::from(new_value);

    // End-around carry: fold anything above bit 15 back into the low half.
    while acc >> 16 != 0 {
        acc = (acc >> 16) + (acc & 0xFFFF);
    }

    // `acc` now fits in 16 bits, so narrowing before complementing is the same
    // as complementing the u32 and truncating.
    !(acc as u16)
}
211
212pub fn update_checksum_32(old_checksum: u16, old_addr: u32, new_addr: u32) -> u16 {
215 let old_hi = (old_addr >> 16) as u16;
216 let old_lo = old_addr as u16;
217 let new_hi = (new_addr >> 16) as u16;
218 let new_lo = new_addr as u16;
219
220 let c1 = update_checksum(old_checksum, old_hi, new_hi);
221 update_checksum(c1, old_lo, new_lo)
222}
223
224pub struct NatEngine {
230 pub port_pool: PortPool,
232 pub mappings: BTreeMap<ConntrackKey, NatMapping>,
234 pub masquerade_addr: Ipv4Address,
236 pub total_translations: u64,
238}
239
240impl NatEngine {
241 pub fn new() -> Self {
243 Self {
244 port_pool: PortPool::new(),
245 mappings: BTreeMap::new(),
246 masquerade_addr: Ipv4Address::ANY,
247 total_translations: 0,
248 }
249 }
250
251 pub fn set_masquerade_addr(&mut self, addr: Ipv4Address) {
253 self.masquerade_addr = addr;
254 }
255
256 pub fn translate_outbound_snat(
261 &mut self,
262 key: &ConntrackKey,
263 new_src_ip: Ipv4Address,
264 ) -> Option<NatMapping> {
265 if let Some(mapping) = self.mappings.get(key) {
267 return Some(*mapping);
268 }
269
270 let new_port = self.port_pool.allocate()?;
272
273 let mapping = NatMapping {
274 nat_type: NatType::Snat,
275 original_src_ip: key.src_ip,
276 original_src_port: key.src_port,
277 original_dst_ip: key.dst_ip,
278 original_dst_port: key.dst_port,
279 translated_src_ip: new_src_ip,
280 translated_src_port: new_port,
281 translated_dst_ip: key.dst_ip,
282 translated_dst_port: key.dst_port,
283 };
284
285 self.mappings.insert(*key, mapping);
286 self.total_translations += 1;
287 Some(mapping)
288 }
289
290 pub fn translate_outbound_masquerade(&mut self, key: &ConntrackKey) -> Option<NatMapping> {
294 let addr = self.masquerade_addr;
295 if addr == Ipv4Address::ANY {
296 return None;
297 }
298
299 if let Some(mapping) = self.mappings.get(key) {
301 return Some(*mapping);
302 }
303
304 let new_port = self.port_pool.allocate()?;
305
306 let mapping = NatMapping {
307 nat_type: NatType::Masquerade,
308 original_src_ip: key.src_ip,
309 original_src_port: key.src_port,
310 original_dst_ip: key.dst_ip,
311 original_dst_port: key.dst_port,
312 translated_src_ip: addr,
313 translated_src_port: new_port,
314 translated_dst_ip: key.dst_ip,
315 translated_dst_port: key.dst_port,
316 };
317
318 self.mappings.insert(*key, mapping);
319 self.total_translations += 1;
320 Some(mapping)
321 }
322
323 pub fn translate_inbound_dnat(
327 &mut self,
328 key: &ConntrackKey,
329 new_dst_ip: Ipv4Address,
330 new_dst_port: Port,
331 ) -> Option<NatMapping> {
332 if let Some(mapping) = self.mappings.get(key) {
334 return Some(*mapping);
335 }
336
337 let mapping = NatMapping {
338 nat_type: NatType::Dnat,
339 original_src_ip: key.src_ip,
340 original_src_port: key.src_port,
341 original_dst_ip: key.dst_ip,
342 original_dst_port: key.dst_port,
343 translated_src_ip: key.src_ip,
344 translated_src_port: key.src_port,
345 translated_dst_ip: new_dst_ip,
346 translated_dst_port: new_dst_port,
347 };
348
349 self.mappings.insert(*key, mapping);
350 self.total_translations += 1;
351 Some(mapping)
352 }
353
354 pub fn lookup_reverse(&self, reply_key: &ConntrackKey) -> Option<&NatMapping> {
359 for mapping in self.mappings.values() {
364 match mapping.nat_type {
365 NatType::Snat | NatType::Masquerade => {
366 if mapping.translated_src_ip == reply_key.dst_ip
367 && mapping.translated_src_port == reply_key.dst_port
368 && mapping.original_dst_ip == reply_key.src_ip
369 {
370 return Some(mapping);
371 }
372 }
373 NatType::Dnat => {
374 if mapping.translated_dst_ip == reply_key.src_ip
375 && mapping.translated_dst_port == reply_key.src_port
376 && mapping.original_src_ip == reply_key.dst_ip
377 {
378 return Some(mapping);
379 }
380 }
381 }
382 }
383 None
384 }
385
386 pub fn remove_mapping(&mut self, key: &ConntrackKey) -> Option<NatMapping> {
388 if let Some(mapping) = self.mappings.remove(key) {
389 match mapping.nat_type {
391 NatType::Snat | NatType::Masquerade => {
392 self.port_pool.release(mapping.translated_src_port);
393 }
394 NatType::Dnat => {}
395 }
396 Some(mapping)
397 } else {
398 None
399 }
400 }
401
402 pub fn mapping_count(&self) -> usize {
404 self.mappings.len()
405 }
406}
407
408impl Default for NatEngine {
409 fn default() -> Self {
410 Self::new()
411 }
412}
413
414static NAT_ENGINE: GlobalState<spin::Mutex<NatEngine>> = GlobalState::new();
419
420pub fn init() -> Result<(), KernelError> {
422 NAT_ENGINE
423 .init(spin::Mutex::new(NatEngine::new()))
424 .map_err(|_| KernelError::InvalidAddress { addr: 0 })?;
425 Ok(())
426}
427
428pub fn with_nat<R, F: FnOnce(&mut NatEngine) -> R>(f: F) -> Option<R> {
430 NAT_ENGINE.with(|lock| {
431 let mut engine = lock.lock();
432 f(&mut engine)
433 })
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation: full one's-complement Internet checksum over
    /// `words`, recomputed from scratch. Incremental updates must agree with it.
    fn full_checksum(words: &[u16]) -> u16 {
        let mut sum: u32 = words.iter().map(|&w| u32::from(w)).sum();
        while sum > 0xFFFF {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }
        !(sum as u16)
    }

    #[test]
    fn test_port_pool_allocate() {
        let mut pool = PortPool::new();
        let port = pool.allocate().unwrap();
        assert_eq!(port, PORT_POOL_START);
        assert_eq!(pool.allocated(), 1);
        assert!(pool.is_allocated(port));
    }

    #[test]
    fn test_port_pool_release() {
        let mut pool = PortPool::new();
        let port = pool.allocate().unwrap();
        assert!(pool.release(port));
        assert_eq!(pool.allocated(), 0);
        assert!(!pool.is_allocated(port));
    }

    #[test]
    fn test_port_pool_release_invalid() {
        let mut pool = PortPool::new();
        // Outside the pool range.
        assert!(!pool.release(80));
        // Inside the range but never allocated.
        assert!(!pool.release(PORT_POOL_START));
        assert_eq!(pool.allocated(), 0);
    }

    #[test]
    fn test_port_pool_double_release() {
        let mut pool = PortPool::new();
        let port = pool.allocate().unwrap();
        assert!(pool.release(port));
        assert!(!pool.release(port));
        assert_eq!(pool.allocated(), 0);
    }

    #[test]
    fn test_port_pool_sequential_allocation() {
        let mut pool = PortPool::new();
        let p1 = pool.allocate().unwrap();
        let p2 = pool.allocate().unwrap();
        let p3 = pool.allocate().unwrap();
        assert_eq!(p1, PORT_POOL_START);
        assert_eq!(p2, PORT_POOL_START + 1);
        assert_eq!(p3, PORT_POOL_START + 2);
        assert_eq!(pool.allocated(), 3);
    }

    #[test]
    fn test_port_pool_reuse_released() {
        let mut pool = PortPool::new();
        let p1 = pool.allocate().unwrap();
        let _p2 = pool.allocate().unwrap();
        pool.release(p1);
        let p3 = pool.allocate().unwrap();
        // The lowest free port is handed out first.
        assert_eq!(p3, p1);
    }

    #[test]
    fn test_port_pool_exhaustion() {
        let mut pool = PortPool::new();
        for expected in PORT_POOL_START..=PORT_POOL_END {
            assert_eq!(pool.allocate(), Some(expected));
        }
        assert_eq!(usize::from(pool.allocated()), PORT_POOL_SIZE);
        assert_eq!(pool.available(), 0);
        assert_eq!(pool.allocate(), None);

        let freed = PORT_POOL_START + 100;
        assert!(pool.release(freed));
        assert_eq!(pool.available(), 1);
        assert_eq!(pool.allocate(), Some(freed));
    }

    #[test]
    fn test_checksum_update_identity() {
        let checksum = 0x1234;
        // Replacing a word with itself must not change the checksum.
        let result = update_checksum(checksum, 0xABCD, 0xABCD);
        assert_eq!(result, checksum);
    }

    #[test]
    fn test_checksum_update_basic() {
        let result = update_checksum(0x0000, 0x5555, 0x5555);
        assert_eq!(result, 0x0000);
    }

    #[test]
    fn test_checksum_update_matches_recompute() {
        let mut words: [u16; 10] = [
            0x4500, 0x0054, 0x1C46, 0x4000, 0x4001, 0x0000, 0xC0A8, 0x0164, 0x0808, 0x0808,
        ];
        let before = full_checksum(&words);
        let old = words[3];
        words[3] = 0xBEEF;
        assert_eq!(update_checksum(before, old, 0xBEEF), full_checksum(&words));
    }

    #[test]
    fn test_checksum_update_32_identity() {
        let checksum = 0xABCD;
        let addr: u32 = 0xC0A80101;
        let result = update_checksum_32(checksum, addr, addr);
        assert_eq!(result, checksum);
    }

    #[test]
    fn test_checksum_update_32_matches_recompute() {
        let old_addr: u32 = 0xC0A8_0164;
        let new_addr: u32 = 0xCB00_7101;
        let mut words: [u16; 10] = [
            0x4500,
            0x0054,
            0x1C46,
            0x4000,
            0x4001,
            0x0000,
            (old_addr >> 16) as u16,
            old_addr as u16,
            0x0808,
            0x0808,
        ];
        let before = full_checksum(&words);
        words[6] = (new_addr >> 16) as u16;
        words[7] = new_addr as u16;
        assert_eq!(
            update_checksum_32(before, old_addr, new_addr),
            full_checksum(&words)
        );
    }

    #[test]
    fn test_nat_engine_snat() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(8, 8, 8, 8),
            12345,
            53,
            ConntrackKey::PROTO_UDP,
        );
        let public_ip = Ipv4Address::new(203, 0, 113, 1);

        let mapping = engine.translate_outbound_snat(&key, public_ip).unwrap();
        assert_eq!(mapping.nat_type, NatType::Snat);
        assert_eq!(mapping.original_src_ip, Ipv4Address::new(192, 168, 1, 100));
        assert_eq!(mapping.translated_src_ip, public_ip);
        assert!(mapping.translated_src_port >= PORT_POOL_START);
        assert_eq!(engine.mapping_count(), 1);
    }

    #[test]
    fn test_nat_engine_snat_is_idempotent() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(8, 8, 8, 8),
            12345,
            53,
            ConntrackKey::PROTO_UDP,
        );
        let public_ip = Ipv4Address::new(203, 0, 113, 1);

        let first = engine.translate_outbound_snat(&key, public_ip).unwrap();
        let second = engine.translate_outbound_snat(&key, public_ip).unwrap();
        assert_eq!(first, second);
        // The second call must reuse the mapping, not allocate another port.
        assert_eq!(engine.port_pool.allocated(), 1);
        assert_eq!(engine.total_translations, 1);
        assert_eq!(engine.mapping_count(), 1);
    }

    #[test]
    fn test_nat_engine_masquerade() {
        let mut engine = NatEngine::new();
        engine.set_masquerade_addr(Ipv4Address::new(203, 0, 113, 1));
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 50),
            Ipv4Address::new(1, 1, 1, 1),
            5000,
            443,
            ConntrackKey::PROTO_TCP,
        );

        let mapping = engine.translate_outbound_masquerade(&key).unwrap();
        assert_eq!(mapping.nat_type, NatType::Masquerade);
        assert_eq!(mapping.translated_src_ip, Ipv4Address::new(203, 0, 113, 1));
    }

    #[test]
    fn test_nat_engine_masquerade_no_addr() {
        let mut engine = NatEngine::new();
        // No masquerade address configured.
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 50),
            Ipv4Address::new(1, 1, 1, 1),
            5000,
            443,
            ConntrackKey::PROTO_TCP,
        );
        assert!(engine.translate_outbound_masquerade(&key).is_none());
    }

    #[test]
    fn test_nat_engine_dnat() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(8, 8, 8, 8),
            Ipv4Address::new(203, 0, 113, 1),
            5000,
            80,
            ConntrackKey::PROTO_TCP,
        );
        let internal_ip = Ipv4Address::new(192, 168, 1, 10);

        let mapping = engine
            .translate_inbound_dnat(&key, internal_ip, 8080)
            .unwrap();
        assert_eq!(mapping.nat_type, NatType::Dnat);
        assert_eq!(mapping.translated_dst_ip, internal_ip);
        assert_eq!(mapping.translated_dst_port, 8080);
    }

    #[test]
    fn test_nat_engine_lookup_reverse_snat() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(8, 8, 8, 8),
            12345,
            53,
            ConntrackKey::PROTO_UDP,
        );
        let public_ip = Ipv4Address::new(203, 0, 113, 1);
        let mapping = engine.translate_outbound_snat(&key, public_ip).unwrap();

        let reply = ConntrackKey::new(
            Ipv4Address::new(8, 8, 8, 8),
            public_ip,
            53,
            mapping.translated_src_port,
            ConntrackKey::PROTO_UDP,
        );
        assert_eq!(engine.lookup_reverse(&reply).copied(), Some(mapping));

        let unrelated = ConntrackKey::new(
            Ipv4Address::new(9, 9, 9, 9),
            public_ip,
            53,
            mapping.translated_src_port,
            ConntrackKey::PROTO_UDP,
        );
        assert!(engine.lookup_reverse(&unrelated).is_none());
    }

    #[test]
    fn test_nat_engine_lookup_reverse_dnat() {
        let mut engine = NatEngine::new();
        let client = Ipv4Address::new(8, 8, 8, 8);
        let key = ConntrackKey::new(
            client,
            Ipv4Address::new(203, 0, 113, 1),
            5000,
            80,
            ConntrackKey::PROTO_TCP,
        );
        let internal_ip = Ipv4Address::new(192, 168, 1, 10);
        let mapping = engine
            .translate_inbound_dnat(&key, internal_ip, 8080)
            .unwrap();

        let reply = ConntrackKey::new(internal_ip, client, 8080, 5000, ConntrackKey::PROTO_TCP);
        assert_eq!(engine.lookup_reverse(&reply).copied(), Some(mapping));
    }

    #[test]
    fn test_nat_engine_remove_mapping() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(8, 8, 8, 8),
            12345,
            53,
            ConntrackKey::PROTO_UDP,
        );
        let public_ip = Ipv4Address::new(203, 0, 113, 1);

        let mapping = engine.translate_outbound_snat(&key, public_ip).unwrap();
        let allocated_port = mapping.translated_src_port;
        assert!(engine.port_pool.is_allocated(allocated_port));

        engine.remove_mapping(&key);
        assert_eq!(engine.mapping_count(), 0);
        assert!(!engine.port_pool.is_allocated(allocated_port));
    }

    #[test]
    fn test_nat_engine_remove_dnat_leaves_pool_untouched() {
        let mut engine = NatEngine::new();
        let key = ConntrackKey::new(
            Ipv4Address::new(8, 8, 8, 8),
            Ipv4Address::new(203, 0, 113, 1),
            5000,
            80,
            ConntrackKey::PROTO_TCP,
        );
        let mapping = engine
            .translate_inbound_dnat(&key, Ipv4Address::new(192, 168, 1, 10), 8080)
            .unwrap();
        assert_eq!(engine.port_pool.allocated(), 0);

        assert_eq!(engine.remove_mapping(&key), Some(mapping));
        assert_eq!(engine.remove_mapping(&key), None);
        assert_eq!(engine.port_pool.allocated(), 0);
    }
}