1#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::collections::BTreeMap;
12#[cfg(feature = "alloc")]
13use alloc::vec::Vec;
14
15use super::rules::Protocol;
16use crate::{error::KernelError, net::Ipv4Address, sync::once_lock::GlobalState};
17
/// Default capacity of a table built with `ConntrackTable::new` (65536 flows).
const MAX_CONNTRACK_ENTRIES: usize = 1 << 16;

// Idle timeouts below are in table ticks and are measured from an entry's
// `last_seen`. NOTE(review): the values read like seconds — confirm the tick unit.

/// Idle timeout for a TCP flow once its handshake has completed (7200 ticks).
const TCP_ESTABLISHED_TIMEOUT: u64 = 2 * 60 * 60;

/// Idle timeout for a freshly created TCP flow; also applied after an RST (120 ticks).
const TCP_NEW_TIMEOUT: u64 = 2 * 60;

/// Idle timeout for a TCP flow that has finished its FIN teardown (120 ticks).
const TCP_TIME_WAIT_TIMEOUT: u64 = 2 * 60;

/// Idle timeout for UDP flows and for any protocol without its own timeout.
const UDP_TIMEOUT: u64 = 30;

/// Idle timeout for ICMP flows.
const ICMP_TIMEOUT: u64 = 30;
39
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
/// Directional five-tuple identifying one direction of a flow.
///
/// The reply direction of the same flow is `key.reverse()`. The derived `Ord`
/// compares fields in declaration order, which fixes the iteration order of the
/// `BTreeMap` inside `ConntrackTable`; keep the field order stable.
pub struct ConntrackKey {
    /// Source IPv4 address.
    pub src_ip: Ipv4Address,
    /// Destination IPv4 address.
    pub dst_ip: Ipv4Address,
    /// Source port. NOTE(review): what callers store here for port-less
    /// protocols such as ICMP is not visible in this file — confirm.
    pub src_port: u16,
    /// Destination port (same caveat as `src_port`).
    pub dst_port: u16,
    /// IP protocol number, e.g. `ConntrackKey::PROTO_TCP`.
    pub protocol: u8,
}
58
59impl ConntrackKey {
60 pub fn new(
62 src_ip: Ipv4Address,
63 dst_ip: Ipv4Address,
64 src_port: u16,
65 dst_port: u16,
66 protocol: u8,
67 ) -> Self {
68 Self {
69 src_ip,
70 dst_ip,
71 src_port,
72 dst_port,
73 protocol,
74 }
75 }
76
77 pub fn reverse(&self) -> Self {
79 Self {
80 src_ip: self.dst_ip,
81 dst_ip: self.src_ip,
82 src_port: self.dst_port,
83 dst_port: self.src_port,
84 protocol: self.protocol,
85 }
86 }
87
88 pub const PROTO_TCP: u8 = 6;
90 pub const PROTO_UDP: u8 = 17;
92 pub const PROTO_ICMP: u8 = 1;
94
95 pub fn protocol_to_num(proto: Protocol) -> u8 {
97 match proto {
98 Protocol::Tcp => Self::PROTO_TCP,
99 Protocol::Udp => Self::PROTO_UDP,
100 Protocol::Icmp => Self::PROTO_ICMP,
101 Protocol::Icmpv6 => 58,
102 Protocol::Any => 0,
103 }
104 }
105}
106
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
/// Flow-level state reported to callers of the conntrack table.
pub enum ConntrackState {
    #[default]
    /// Flow created, no reply-direction packet seen yet.
    New,
    /// A reply has been seen, or the TCP handshake completed.
    Established,
    /// NOTE(review): not assigned anywhere in this file; presumably meant for
    /// flows related to another tracked flow — confirm with its users.
    Related,
    /// Flow unusable: set when a TCP RST is seen, and returned by
    /// `track_packet` when the table is full.
    Invalid,
    /// TCP teardown finished; the entry lingers for `TCP_TIME_WAIT_TIMEOUT`.
    TimeWait,
}
126
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
/// Simplified TCP connection state, advanced by `update_tcp_state`.
///
/// "Originator" is the side that sent the packet that created the entry;
/// "reply" is the opposite direction.
pub enum TcpConnState {
    #[default]
    /// Not tracking TCP (used for non-TCP entries).
    None,
    /// Originator SYN seen; new TCP entries start here.
    SynSent,
    /// SYN-ACK seen from the reply side.
    SynRecv,
    /// Originator ACKed the SYN-ACK; handshake complete.
    Established,
    /// Originator sent FIN first.
    FinWait1,
    /// Reply side ACKed the originator's FIN.
    FinWait2,
    /// Both sides sent FIN; waiting for the final ACK.
    Closing,
    /// Reply side sent FIN first.
    CloseWait,
    /// Originator sent its FIN after the reply side's; waiting for the reply ACK.
    LastAck,
    /// Teardown complete.
    TimeWait,
    /// Closed by RST, or after the final ACK in `LastAck`.
    Closed,
}
154
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Address/port translation attached to a tracked flow.
///
/// No code in this file populates it. The field meanings below follow the field
/// names — NOTE(review): confirm against the NAT code that fills it in.
pub struct NatInfo {
    /// Source address before translation.
    pub original_src_ip: Ipv4Address,
    /// Source port before translation.
    pub original_src_port: u16,
    /// Source address after translation.
    pub translated_src_ip: Ipv4Address,
    /// Source port after translation.
    pub translated_src_port: u16,
    /// Destination address before translation.
    pub original_dst_ip: Ipv4Address,
    /// Destination port before translation.
    pub original_dst_port: u16,
    /// Destination address after translation.
    pub translated_dst_ip: Ipv4Address,
    /// Destination port after translation.
    pub translated_dst_port: u16,
}
175
#[derive(Debug, Clone)]
/// Per-flow record stored in `ConntrackTable`.
pub struct ConntrackEntry {
    /// Key of the packet that created the entry (originator direction).
    pub key: ConntrackKey,
    /// Flow-level state reported to callers.
    pub state: ConntrackState,
    /// TCP state machine position; `TcpConnState::None` for non-TCP flows.
    pub tcp_state: TcpConnState,
    /// Idle timeout in ticks, measured from `last_seen`; the entry is expired
    /// once `last_seen + timeout_ticks` is reached.
    pub timeout_ticks: u64,
    /// Packets seen in either direction, including the one that created the entry.
    pub packet_count: u64,
    /// Bytes seen in either direction.
    pub byte_count: u64,
    /// Tick at which the most recent packet was seen.
    pub last_seen: u64,
    /// Address translation for this flow, if any (not set by code in this file).
    pub nat_info: Option<NatInfo>,
    /// Whether a reply-direction packet has been seen.
    pub reply_seen: bool,
}
203
204impl ConntrackEntry {
205 pub fn new(key: ConntrackKey, protocol: u8) -> Self {
207 let timeout = match protocol {
208 ConntrackKey::PROTO_TCP => TCP_NEW_TIMEOUT,
209 ConntrackKey::PROTO_UDP => UDP_TIMEOUT,
210 ConntrackKey::PROTO_ICMP => ICMP_TIMEOUT,
211 _ => UDP_TIMEOUT,
212 };
213
214 Self {
215 key,
216 state: ConntrackState::New,
217 tcp_state: if protocol == ConntrackKey::PROTO_TCP {
218 TcpConnState::SynSent
219 } else {
220 TcpConnState::None
221 },
222 timeout_ticks: timeout,
223 packet_count: 1,
224 byte_count: 0,
225 last_seen: 0,
226 nat_info: None,
227 reply_seen: false,
228 }
229 }
230
231 pub fn is_expired(&self, current_tick: u64) -> bool {
233 current_tick >= self.last_seen + self.timeout_ticks
234 }
235
236 pub fn update(&mut self, current_tick: u64, bytes: u64) {
238 self.last_seen = current_tick;
239 self.packet_count += 1;
240 self.byte_count += bytes;
241 }
242
243 pub fn mark_reply_seen(&mut self) {
245 if !self.reply_seen {
246 self.reply_seen = true;
247 if self.state == ConntrackState::New {
248 self.state = ConntrackState::Established;
249 if self.key.protocol == ConntrackKey::PROTO_TCP {
250 self.timeout_ticks = TCP_ESTABLISHED_TIMEOUT;
251 }
252 }
253 }
254 }
255}
256
// Bits of the TCP header flags byte that the state tracker inspects.

/// SYN: synchronize sequence numbers (bit 1).
const TCP_SYN: u8 = 1 << 1;
/// ACK: acknowledgment field is significant (bit 4).
const TCP_ACK: u8 = 1 << 4;
/// FIN: sender has finished sending (bit 0).
const TCP_FIN: u8 = 1 << 0;
/// RST: reset the connection (bit 2).
const TCP_RST: u8 = 1 << 2;
266
/// Advances `entry`'s TCP state machine for one packet.
///
/// `tcp_flags` is the packet's TCP flags byte; only SYN, ACK, FIN and RST are
/// inspected, and sequence numbers are not checked. `is_reply` is `true` when
/// the packet travels opposite to `entry.key`. Besides `tcp_state`, this may
/// update `entry.state`, `entry.timeout_ticks` and, via `mark_reply_seen`,
/// `entry.reply_seen`.
pub fn update_tcp_state(entry: &mut ConntrackEntry, tcp_flags: u8, is_reply: bool) {
    let has_syn = tcp_flags & TCP_SYN != 0;
    let has_ack = tcp_flags & TCP_ACK != 0;
    let has_fin = tcp_flags & TCP_FIN != 0;
    let has_rst = tcp_flags & TCP_RST != 0;

    // An RST from either side closes the flow immediately, whatever the current state.
    if has_rst {
        entry.tcp_state = TcpConnState::Closed;
        entry.state = ConntrackState::Invalid;
        entry.timeout_ticks = TCP_NEW_TIMEOUT;
        return;
    }

    entry.tcp_state = match entry.tcp_state {
        // A bare SYN starts tracking; anything else leaves the state untouched.
        TcpConnState::None => {
            if has_syn && !has_ack {
                TcpConnState::SynSent
            } else {
                TcpConnState::None
            }
        }
        // Waiting for the responder's SYN-ACK.
        TcpConnState::SynSent => {
            if is_reply && has_syn && has_ack {
                entry.mark_reply_seen();
                TcpConnState::SynRecv
            } else {
                TcpConnState::SynSent
            }
        }
        // Waiting for the originator's ACK that completes the handshake.
        TcpConnState::SynRecv => {
            if !is_reply && has_ack {
                entry.state = ConntrackState::Established;
                entry.timeout_ticks = TCP_ESTABLISHED_TIMEOUT;
                TcpConnState::Established
            } else {
                TcpConnState::SynRecv
            }
        }
        // The first FIN decides which side is closing actively.
        TcpConnState::Established => {
            if has_fin {
                if is_reply {
                    TcpConnState::CloseWait
                } else {
                    TcpConnState::FinWait1
                }
            } else {
                TcpConnState::Established
            }
        }
        // Originator closed; a combined FIN-ACK from the reply side skips FinWait2.
        TcpConnState::FinWait1 => {
            if is_reply && has_fin && has_ack {
                entry.timeout_ticks = TCP_TIME_WAIT_TIMEOUT;
                entry.state = ConntrackState::TimeWait;
                TcpConnState::TimeWait
            } else if is_reply && has_ack {
                TcpConnState::FinWait2
            } else if is_reply && has_fin {
                // FIN without ACK: both sides are closing at once.
                TcpConnState::Closing
            } else {
                TcpConnState::FinWait1
            }
        }
        // Originator's FIN acknowledged; waiting for the reply side's FIN.
        TcpConnState::FinWait2 => {
            if is_reply && has_fin {
                entry.timeout_ticks = TCP_TIME_WAIT_TIMEOUT;
                entry.state = ConntrackState::TimeWait;
                TcpConnState::TimeWait
            } else {
                TcpConnState::FinWait2
            }
        }
        // Simultaneous close: an ACK from either direction finishes it.
        TcpConnState::Closing => {
            if has_ack {
                entry.timeout_ticks = TCP_TIME_WAIT_TIMEOUT;
                entry.state = ConntrackState::TimeWait;
                TcpConnState::TimeWait
            } else {
                TcpConnState::Closing
            }
        }
        // Reply side closed first; waiting for the originator's FIN.
        TcpConnState::CloseWait => {
            if !is_reply && has_fin {
                TcpConnState::LastAck
            } else {
                TcpConnState::CloseWait
            }
        }
        // Waiting for the reply side to ACK the originator's FIN.
        TcpConnState::LastAck => {
            if is_reply && has_ack {
                entry.state = ConntrackState::TimeWait;
                entry.timeout_ticks = TCP_TIME_WAIT_TIMEOUT;
                TcpConnState::Closed
            } else {
                TcpConnState::LastAck
            }
        }
        // Terminal states: no packet (not even a new SYN) moves the flow out of
        // them here; the entry stays until it expires and is removed.
        TcpConnState::TimeWait | TcpConnState::Closed => entry.tcp_state,
    };
}
371
/// Connection-tracking table with a tick-based clock used for entry expiry.
pub struct ConntrackTable {
    /// Tracked flows, keyed by the key of the packet that created each entry.
    entries: BTreeMap<ConntrackKey, ConntrackEntry>,
    /// Capacity; new flows are refused once the table is full and `gc` cannot make room.
    max_entries: usize,
    /// Current time in ticks, advanced by `tick` or set by `set_tick`.
    current_tick: u64,
    /// Number of entries created over the table's lifetime.
    total_created: u64,
    /// Number of entries removed because they expired.
    total_expired: u64,
}
389
390impl ConntrackTable {
391 pub fn new() -> Self {
393 Self {
394 entries: BTreeMap::new(),
395 max_entries: MAX_CONNTRACK_ENTRIES,
396 current_tick: 0,
397 total_created: 0,
398 total_expired: 0,
399 }
400 }
401
402 pub fn with_max_entries(max: usize) -> Self {
404 Self {
405 entries: BTreeMap::new(),
406 max_entries: max,
407 current_tick: 0,
408 total_created: 0,
409 total_expired: 0,
410 }
411 }
412
413 pub fn tick(&mut self) {
415 self.current_tick += 1;
416 }
417
418 pub fn set_tick(&mut self, tick: u64) {
420 self.current_tick = tick;
421 }
422
423 pub fn entry_count(&self) -> usize {
425 self.entries.len()
426 }
427
428 pub fn lookup(&self, key: &ConntrackKey) -> Option<&ConntrackEntry> {
430 self.entries.get(key)
431 }
432
433 pub fn lookup_mut(&mut self, key: &ConntrackKey) -> Option<&mut ConntrackEntry> {
435 self.entries.get_mut(key)
436 }
437
438 pub fn track_packet(&mut self, key: ConntrackKey, bytes: u64, tcp_flags: u8) -> ConntrackState {
442 let current_tick = self.current_tick;
443
444 if let Some(entry) = self.entries.get_mut(&key) {
446 entry.update(current_tick, bytes);
447 if key.protocol == ConntrackKey::PROTO_TCP {
448 update_tcp_state(entry, tcp_flags, false);
449 }
450 return entry.state;
451 }
452
453 let reverse_key = key.reverse();
455 if let Some(entry) = self.entries.get_mut(&reverse_key) {
456 entry.update(current_tick, bytes);
457 entry.mark_reply_seen();
458 if reverse_key.protocol == ConntrackKey::PROTO_TCP {
459 update_tcp_state(entry, tcp_flags, true);
460 }
461 return entry.state;
462 }
463
464 if self.entries.len() >= self.max_entries {
466 self.gc();
468 if self.entries.len() >= self.max_entries {
469 return ConntrackState::Invalid;
470 }
471 }
472
473 let mut entry = ConntrackEntry::new(key, key.protocol);
474 entry.last_seen = current_tick;
475 entry.byte_count = bytes;
476 self.entries.insert(key, entry);
477 self.total_created += 1;
478 ConntrackState::New
479 }
480
481 pub fn remove(&mut self, key: &ConntrackKey) -> Option<ConntrackEntry> {
483 self.entries.remove(key)
484 }
485
486 pub fn gc(&mut self) -> usize {
488 let current_tick = self.current_tick;
489 let before = self.entries.len();
490
491 let expired_keys: Vec<ConntrackKey> = self
493 .entries
494 .iter()
495 .filter(|(_, entry)| entry.is_expired(current_tick))
496 .map(|(key, _)| *key)
497 .collect();
498
499 for key in &expired_keys {
500 self.entries.remove(key);
501 }
502
503 let removed = before - self.entries.len();
504 self.total_expired += removed as u64;
505 removed
506 }
507
508 pub fn classify_packet(&self, key: &ConntrackKey) -> ConntrackState {
510 if let Some(entry) = self.entries.get(key) {
512 if entry.is_expired(self.current_tick) {
513 return ConntrackState::Invalid;
514 }
515 return entry.state;
516 }
517
518 let reverse = key.reverse();
520 if let Some(entry) = self.entries.get(&reverse) {
521 if entry.is_expired(self.current_tick) {
522 return ConntrackState::Invalid;
523 }
524 if entry.reply_seen {
525 return ConntrackState::Established;
526 }
527 return ConntrackState::New;
528 }
529
530 ConntrackState::New
531 }
532
533 pub fn stats(&self) -> ConntrackStats {
535 ConntrackStats {
536 active_entries: self.entries.len() as u64,
537 max_entries: self.max_entries as u64,
538 total_created: self.total_created,
539 total_expired: self.total_expired,
540 }
541 }
542}
543
544impl Default for ConntrackTable {
545 fn default() -> Self {
546 Self::new()
547 }
548}
549
#[derive(Debug, Clone, Copy, Default)]
/// Snapshot of a `ConntrackTable`'s counters, produced by `ConntrackTable::stats`.
pub struct ConntrackStats {
    /// Entries currently stored, including expired ones not yet removed.
    pub active_entries: u64,
    /// Capacity of the table.
    pub max_entries: u64,
    /// Entries created over the table's lifetime.
    pub total_created: u64,
    /// Entries removed because they expired.
    pub total_expired: u64,
}
558
/// Global conntrack table: installed by `init`, accessed through `with_conntrack`.
static CONNTRACK_TABLE: GlobalState<spin::Mutex<ConntrackTable>> = GlobalState::new();
564
565pub fn init() -> Result<(), KernelError> {
567 CONNTRACK_TABLE
568 .init(spin::Mutex::new(ConntrackTable::new()))
569 .map_err(|_| KernelError::InvalidAddress { addr: 0 })?;
570 Ok(())
571}
572
573pub fn with_conntrack<R, F: FnOnce(&mut ConntrackTable) -> R>(f: F) -> Option<R> {
575 CONNTRACK_TABLE.with(|lock| {
576 let mut table = lock.lock();
577 f(&mut table)
578 })
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Originator-direction TCP key: 192.168.1.100:12345 -> 10.0.0.1:80.
    fn tcp_key() -> ConntrackKey {
        ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(10, 0, 0, 1),
            12345,
            80,
            ConntrackKey::PROTO_TCP,
        )
    }

    /// Originator-direction UDP key: 192.168.1.100:5000 -> 10.0.0.1:53.
    fn udp_key() -> ConntrackKey {
        ConntrackKey::new(
            Ipv4Address::new(192, 168, 1, 100),
            Ipv4Address::new(10, 0, 0, 1),
            5000,
            53,
            ConntrackKey::PROTO_UDP,
        )
    }

    #[test]
    fn test_conntrack_key_reverse() {
        let key = tcp_key();
        let rev = key.reverse();
        assert_eq!(rev.src_ip, key.dst_ip);
        assert_eq!(rev.dst_ip, key.src_ip);
        assert_eq!(rev.src_port, key.dst_port);
        assert_eq!(rev.dst_port, key.src_port);
        assert_eq!(rev.protocol, key.protocol);
    }

    #[test]
    fn test_conntrack_state_default() {
        assert_eq!(ConntrackState::default(), ConntrackState::New);
    }

    #[test]
    fn test_conntrack_entry_new_tcp() {
        let entry = ConntrackEntry::new(tcp_key(), ConntrackKey::PROTO_TCP);
        assert_eq!(entry.state, ConntrackState::New);
        assert_eq!(entry.tcp_state, TcpConnState::SynSent);
        assert_eq!(entry.timeout_ticks, TCP_NEW_TIMEOUT);
        assert!(!entry.reply_seen);
    }

    #[test]
    fn test_conntrack_entry_new_udp() {
        let entry = ConntrackEntry::new(udp_key(), ConntrackKey::PROTO_UDP);
        assert_eq!(entry.state, ConntrackState::New);
        assert_eq!(entry.tcp_state, TcpConnState::None);
        assert_eq!(entry.timeout_ticks, UDP_TIMEOUT);
    }

    #[test]
    fn test_conntrack_entry_expired() {
        let mut entry = ConntrackEntry::new(tcp_key(), ConntrackKey::PROTO_TCP);
        entry.last_seen = 100;
        entry.timeout_ticks = 50;
        // Deadline is last_seen + timeout_ticks = 150, and the deadline tick itself counts as expired.
        assert!(entry.is_expired(151));
        assert!(!entry.is_expired(149));
        assert!(entry.is_expired(150));
    }

    #[test]
    fn test_conntrack_table_track_new() {
        let mut table = ConntrackTable::new();
        let state = table.track_packet(tcp_key(), 64, TCP_SYN);
        assert_eq!(state, ConntrackState::New);
        assert_eq!(table.entry_count(), 1);
    }

    #[test]
    fn test_conntrack_table_track_reply() {
        let mut table = ConntrackTable::new();
        let key = tcp_key();

        // Originator SYN creates the entry.
        table.track_packet(key, 64, TCP_SYN);

        // SYN-ACK from the responder matches the same entry via the reversed key.
        let rev = key.reverse();
        let state = table.track_packet(rev, 64, TCP_SYN | TCP_ACK);
        assert_eq!(state, ConntrackState::Established);
        // The reply must not create a second entry.
        assert_eq!(table.entry_count(), 1);
    }

    #[test]
    fn test_conntrack_table_max_entries() {
        let mut table = ConntrackTable::with_max_entries(2);
        let k1 = ConntrackKey::new(
            Ipv4Address::new(10, 0, 0, 1),
            Ipv4Address::new(10, 0, 0, 2),
            1000,
            80,
            ConntrackKey::PROTO_TCP,
        );
        let k2 = ConntrackKey::new(
            Ipv4Address::new(10, 0, 0, 3),
            Ipv4Address::new(10, 0, 0, 4),
            1001,
            80,
            ConntrackKey::PROTO_TCP,
        );
        let k3 = ConntrackKey::new(
            Ipv4Address::new(10, 0, 0, 5),
            Ipv4Address::new(10, 0, 0, 6),
            1002,
            80,
            ConntrackKey::PROTO_TCP,
        );

        table.track_packet(k1, 64, TCP_SYN);
        table.track_packet(k2, 64, TCP_SYN);

        // Table is full and nothing has expired at tick 0, so gc frees nothing and k3 is refused.
        let state = table.track_packet(k3, 64, TCP_SYN);
        assert_eq!(state, ConntrackState::Invalid);
        assert_eq!(table.entry_count(), 2);
    }

    #[test]
    fn test_conntrack_table_gc() {
        let mut table = ConntrackTable::new();
        let key = tcp_key();

        table.track_packet(key, 64, TCP_SYN);
        assert_eq!(table.entry_count(), 1);

        // Move the clock past the new-TCP timeout so the entry becomes collectable.
        table.set_tick(TCP_NEW_TIMEOUT + 1);
        let removed = table.gc();
        assert_eq!(removed, 1);
        assert_eq!(table.entry_count(), 0);
    }

    #[test]
    fn test_conntrack_classify_unknown() {
        let table = ConntrackTable::new();
        let state = table.classify_packet(&tcp_key());
        assert_eq!(state, ConntrackState::New);
    }

    #[test]
    fn test_tcp_state_full_handshake() {
        let mut entry = ConntrackEntry::new(tcp_key(), ConntrackKey::PROTO_TCP);

        // A new TCP entry starts as if the originator's SYN has been seen.
        assert_eq!(entry.tcp_state, TcpConnState::SynSent);

        // SYN-ACK from the responder.
        update_tcp_state(&mut entry, TCP_SYN | TCP_ACK, true);
        assert_eq!(entry.tcp_state, TcpConnState::SynRecv);

        // Final ACK from the originator completes the handshake.
        update_tcp_state(&mut entry, TCP_ACK, false);
        assert_eq!(entry.tcp_state, TcpConnState::Established);
        assert_eq!(entry.state, ConntrackState::Established);
    }

    #[test]
    fn test_tcp_state_rst() {
        let mut entry = ConntrackEntry::new(tcp_key(), ConntrackKey::PROTO_TCP);
        update_tcp_state(&mut entry, TCP_RST, false);
        assert_eq!(entry.tcp_state, TcpConnState::Closed);
        assert_eq!(entry.state, ConntrackState::Invalid);
    }

    #[test]
    fn test_tcp_state_fin_close() {
        let mut entry = ConntrackEntry::new(tcp_key(), ConntrackKey::PROTO_TCP);

        // Complete the handshake (the originator SYN is implied by the new entry).
        update_tcp_state(&mut entry, TCP_SYN | TCP_ACK, true);
        update_tcp_state(&mut entry, TCP_ACK, false);
        assert_eq!(entry.tcp_state, TcpConnState::Established);

        // Originator closes first.
        update_tcp_state(&mut entry, TCP_FIN, false);
        assert_eq!(entry.tcp_state, TcpConnState::FinWait1);

        // Responder's combined FIN-ACK skips FinWait2 and goes straight to TimeWait.
        update_tcp_state(&mut entry, TCP_FIN | TCP_ACK, true);
        assert_eq!(entry.tcp_state, TcpConnState::TimeWait);
        assert_eq!(entry.state, ConntrackState::TimeWait);
    }

    #[test]
    fn test_conntrack_stats() {
        let mut table = ConntrackTable::new();
        table.track_packet(tcp_key(), 64, TCP_SYN);
        let stats = table.stats();
        assert_eq!(stats.active_entries, 1);
        assert_eq!(stats.total_created, 1);
        assert_eq!(stats.max_entries, MAX_CONNTRACK_ENTRIES as u64);
    }
}