⚠️ VeridianOS Kernel Documentation - This is low-level `no_std` kernel code. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/net/firewall/
conntrack.rs

1//! Connection tracking (conntrack) for stateful packet inspection
2//!
3//! Tracks network connections using 5-tuple keys (src_ip, dst_ip, src_port,
4//! dst_port, protocol). Maintains TCP state machine for accurate connection
5//! lifecycle tracking. Supports garbage collection of expired entries and
6//! enforces a maximum entry limit to prevent resource exhaustion.
7
8#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::collections::BTreeMap;
12#[cfg(feature = "alloc")]
13use alloc::vec::Vec;
14
15use super::rules::Protocol;
16use crate::{error::KernelError, net::Ipv4Address, sync::once_lock::GlobalState};
17
// ============================================================================
// Constants
// ============================================================================
//
// All timeouts are expressed in table ticks; the per-constant notes assume a
// 1 Hz tick, i.e. one tick per second.

/// Upper bound on simultaneously tracked connections (2^16).
const MAX_CONNTRACK_ENTRIES: usize = 1 << 16;

/// Idle lifetime of an established TCP connection: two hours.
const TCP_ESTABLISHED_TIMEOUT: u64 = 2 * 60 * 60;

/// Idle lifetime of a new or half-open TCP connection: two minutes.
const TCP_NEW_TIMEOUT: u64 = 2 * 60;

/// Lifetime of a TCP connection in TIME_WAIT: 2 * MSL with MSL = 60 ticks.
const TCP_TIME_WAIT_TIMEOUT: u64 = 2 * 60;

/// Idle lifetime of a UDP flow: thirty seconds.
const UDP_TIMEOUT: u64 = 30;

/// Idle lifetime of an ICMP exchange: thirty seconds.
const ICMP_TIMEOUT: u64 = 30;
39
40// ============================================================================
41// Connection Tracking Key
42// ============================================================================
43
44/// 5-tuple identifying a unique connection
45#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
46pub struct ConntrackKey {
47    /// Source IP address
48    pub src_ip: Ipv4Address,
49    /// Destination IP address
50    pub dst_ip: Ipv4Address,
51    /// Source port (0 for ICMP)
52    pub src_port: u16,
53    /// Destination port (0 for ICMP)
54    pub dst_port: u16,
55    /// IP protocol
56    pub protocol: u8,
57}
58
59impl ConntrackKey {
60    /// Create a new connection tracking key
61    pub fn new(
62        src_ip: Ipv4Address,
63        dst_ip: Ipv4Address,
64        src_port: u16,
65        dst_port: u16,
66        protocol: u8,
67    ) -> Self {
68        Self {
69            src_ip,
70            dst_ip,
71            src_port,
72            dst_port,
73            protocol,
74        }
75    }
76
77    /// Create the reverse key (for tracking reply direction)
78    pub fn reverse(&self) -> Self {
79        Self {
80            src_ip: self.dst_ip,
81            dst_ip: self.src_ip,
82            src_port: self.dst_port,
83            dst_port: self.src_port,
84            protocol: self.protocol,
85        }
86    }
87
88    /// Protocol number for TCP
89    pub const PROTO_TCP: u8 = 6;
90    /// Protocol number for UDP
91    pub const PROTO_UDP: u8 = 17;
92    /// Protocol number for ICMP
93    pub const PROTO_ICMP: u8 = 1;
94
95    /// Convert from rules::Protocol to protocol number
96    pub fn protocol_to_num(proto: Protocol) -> u8 {
97        match proto {
98            Protocol::Tcp => Self::PROTO_TCP,
99            Protocol::Udp => Self::PROTO_UDP,
100            Protocol::Icmp => Self::PROTO_ICMP,
101            Protocol::Icmpv6 => 58,
102            Protocol::Any => 0,
103        }
104    }
105}
106
// ============================================================================
// Connection State
// ============================================================================

/// High-level state reported for a tracked connection.
///
/// Variant order is kept stable because it determines the discriminants.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ConntrackState {
    /// Opening packet of a connection with no reply seen yet (the default).
    #[default]
    New,
    /// Traffic has been observed in both directions.
    Established,
    /// Belongs to an existing connection, e.g. an ICMP error or FTP data channel.
    Related,
    /// Unexpected packet, reset connection, or packet that could not be tracked.
    Invalid,
    /// TCP connection has closed and is in TIME_WAIT.
    TimeWait,
}
126
/// Fine-grained TCP state, tracked from the originator's point of view.
///
/// Variant order is kept stable because it determines the discriminants.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TcpConnState {
    /// No TCP state (non-TCP flow, or nothing observed yet); the default.
    #[default]
    None,
    /// Originator has sent its SYN.
    SynSent,
    /// Responder answered with SYN-ACK.
    SynRecv,
    /// Three-way handshake finished.
    Established,
    /// Originator sent the first FIN.
    FinWait1,
    /// Responder acknowledged the originator's FIN.
    FinWait2,
    /// Both sides are closing at the same time.
    Closing,
    /// Responder sent a FIN while established.
    CloseWait,
    /// Originator sent its FIN after CLOSE_WAIT.
    LastAck,
    /// Waiting for stray segments of the old connection to drain.
    TimeWait,
    /// Connection is completely closed (or was reset).
    Closed,
}
154
155// ============================================================================
156// NAT Info (stored in conntrack entries)
157// ============================================================================
158
159/// NAT information associated with a connection
160#[derive(Debug, Clone, Copy, PartialEq, Eq)]
161pub struct NatInfo {
162    /// Original source before SNAT
163    pub original_src_ip: Ipv4Address,
164    pub original_src_port: u16,
165    /// Translated source after SNAT
166    pub translated_src_ip: Ipv4Address,
167    pub translated_src_port: u16,
168    /// Original destination before DNAT
169    pub original_dst_ip: Ipv4Address,
170    pub original_dst_port: u16,
171    /// Translated destination after DNAT
172    pub translated_dst_ip: Ipv4Address,
173    pub translated_dst_port: u16,
174}
175
176// ============================================================================
177// Connection Tracking Entry
178// ============================================================================
179
180/// A single connection tracking entry
181#[derive(Debug, Clone)]
182pub struct ConntrackEntry {
183    /// Connection 5-tuple key
184    pub key: ConntrackKey,
185    /// High-level connection state
186    pub state: ConntrackState,
187    /// Detailed TCP state (only meaningful for TCP)
188    pub tcp_state: TcpConnState,
189    /// Timeout in ticks (entry expires when current_tick >= last_seen +
190    /// timeout)
191    pub timeout_ticks: u64,
192    /// Number of packets seen
193    pub packet_count: u64,
194    /// Number of bytes seen
195    pub byte_count: u64,
196    /// Tick counter when entry was last updated
197    pub last_seen: u64,
198    /// Optional NAT translation info
199    pub nat_info: Option<NatInfo>,
200    /// Whether reply traffic has been seen
201    pub reply_seen: bool,
202}
203
204impl ConntrackEntry {
205    /// Create a new conntrack entry for a first-seen packet
206    pub fn new(key: ConntrackKey, protocol: u8) -> Self {
207        let timeout = match protocol {
208            ConntrackKey::PROTO_TCP => TCP_NEW_TIMEOUT,
209            ConntrackKey::PROTO_UDP => UDP_TIMEOUT,
210            ConntrackKey::PROTO_ICMP => ICMP_TIMEOUT,
211            _ => UDP_TIMEOUT,
212        };
213
214        Self {
215            key,
216            state: ConntrackState::New,
217            tcp_state: if protocol == ConntrackKey::PROTO_TCP {
218                TcpConnState::SynSent
219            } else {
220                TcpConnState::None
221            },
222            timeout_ticks: timeout,
223            packet_count: 1,
224            byte_count: 0,
225            last_seen: 0,
226            nat_info: None,
227            reply_seen: false,
228        }
229    }
230
231    /// Check if this entry has expired
232    pub fn is_expired(&self, current_tick: u64) -> bool {
233        current_tick >= self.last_seen + self.timeout_ticks
234    }
235
236    /// Update the last-seen timestamp and packet counters
237    pub fn update(&mut self, current_tick: u64, bytes: u64) {
238        self.last_seen = current_tick;
239        self.packet_count += 1;
240        self.byte_count += bytes;
241    }
242
243    /// Mark that reply traffic has been seen, promoting to Established
244    pub fn mark_reply_seen(&mut self) {
245        if !self.reply_seen {
246            self.reply_seen = true;
247            if self.state == ConntrackState::New {
248                self.state = ConntrackState::Established;
249                if self.key.protocol == ConntrackKey::PROTO_TCP {
250                    self.timeout_ticks = TCP_ESTABLISHED_TIMEOUT;
251                }
252            }
253        }
254    }
255}
256
257// ============================================================================
258// TCP State Machine
259// ============================================================================
260
261/// TCP flag constants for state machine transitions
262const TCP_SYN: u8 = 0x02;
263const TCP_ACK: u8 = 0x10;
264const TCP_FIN: u8 = 0x01;
265const TCP_RST: u8 = 0x04;
266
267/// Update the TCP connection state based on observed flags
268///
269/// `is_reply` indicates whether this packet is from the responder (reply
270/// direction).
271pub fn update_tcp_state(entry: &mut ConntrackEntry, tcp_flags: u8, is_reply: bool) {
272    let has_syn = tcp_flags & TCP_SYN != 0;
273    let has_ack = tcp_flags & TCP_ACK != 0;
274    let has_fin = tcp_flags & TCP_FIN != 0;
275    let has_rst = tcp_flags & TCP_RST != 0;
276
277    // RST immediately closes the connection
278    if has_rst {
279        entry.tcp_state = TcpConnState::Closed;
280        entry.state = ConntrackState::Invalid;
281        entry.timeout_ticks = TCP_NEW_TIMEOUT;
282        return;
283    }
284
285    entry.tcp_state = match entry.tcp_state {
286        TcpConnState::None => {
287            if has_syn && !has_ack {
288                TcpConnState::SynSent
289            } else {
290                TcpConnState::None
291            }
292        }
293        TcpConnState::SynSent => {
294            if is_reply && has_syn && has_ack {
295                entry.mark_reply_seen();
296                TcpConnState::SynRecv
297            } else {
298                TcpConnState::SynSent
299            }
300        }
301        TcpConnState::SynRecv => {
302            if !is_reply && has_ack {
303                entry.state = ConntrackState::Established;
304                entry.timeout_ticks = TCP_ESTABLISHED_TIMEOUT;
305                TcpConnState::Established
306            } else {
307                TcpConnState::SynRecv
308            }
309        }
310        TcpConnState::Established => {
311            if has_fin {
312                if is_reply {
313                    TcpConnState::CloseWait
314                } else {
315                    TcpConnState::FinWait1
316                }
317            } else {
318                TcpConnState::Established
319            }
320        }
321        TcpConnState::FinWait1 => {
322            if is_reply && has_fin && has_ack {
323                entry.timeout_ticks = TCP_TIME_WAIT_TIMEOUT;
324                entry.state = ConntrackState::TimeWait;
325                TcpConnState::TimeWait
326            } else if is_reply && has_ack {
327                TcpConnState::FinWait2
328            } else if is_reply && has_fin {
329                TcpConnState::Closing
330            } else {
331                TcpConnState::FinWait1
332            }
333        }
334        TcpConnState::FinWait2 => {
335            if is_reply && has_fin {
336                entry.timeout_ticks = TCP_TIME_WAIT_TIMEOUT;
337                entry.state = ConntrackState::TimeWait;
338                TcpConnState::TimeWait
339            } else {
340                TcpConnState::FinWait2
341            }
342        }
343        TcpConnState::Closing => {
344            if has_ack {
345                entry.timeout_ticks = TCP_TIME_WAIT_TIMEOUT;
346                entry.state = ConntrackState::TimeWait;
347                TcpConnState::TimeWait
348            } else {
349                TcpConnState::Closing
350            }
351        }
352        TcpConnState::CloseWait => {
353            if !is_reply && has_fin {
354                TcpConnState::LastAck
355            } else {
356                TcpConnState::CloseWait
357            }
358        }
359        TcpConnState::LastAck => {
360            if is_reply && has_ack {
361                entry.state = ConntrackState::TimeWait;
362                entry.timeout_ticks = TCP_TIME_WAIT_TIMEOUT;
363                TcpConnState::Closed
364            } else {
365                TcpConnState::LastAck
366            }
367        }
368        TcpConnState::TimeWait | TcpConnState::Closed => entry.tcp_state,
369    };
370}
371
372// ============================================================================
373// Connection Tracking Table
374// ============================================================================
375
376/// Connection tracking table managing all active connections
377pub struct ConntrackTable {
378    /// Active connections indexed by 5-tuple key
379    entries: BTreeMap<ConntrackKey, ConntrackEntry>,
380    /// Maximum number of entries
381    max_entries: usize,
382    /// Current tick counter (monotonically increasing)
383    current_tick: u64,
384    /// Total entries created over lifetime
385    total_created: u64,
386    /// Total entries expired/garbage-collected
387    total_expired: u64,
388}
389
390impl ConntrackTable {
391    /// Create a new connection tracking table
392    pub fn new() -> Self {
393        Self {
394            entries: BTreeMap::new(),
395            max_entries: MAX_CONNTRACK_ENTRIES,
396            current_tick: 0,
397            total_created: 0,
398            total_expired: 0,
399        }
400    }
401
402    /// Create with a custom maximum entry count
403    pub fn with_max_entries(max: usize) -> Self {
404        Self {
405            entries: BTreeMap::new(),
406            max_entries: max,
407            current_tick: 0,
408            total_created: 0,
409            total_expired: 0,
410        }
411    }
412
413    /// Advance the tick counter
414    pub fn tick(&mut self) {
415        self.current_tick += 1;
416    }
417
418    /// Set the current tick counter
419    pub fn set_tick(&mut self, tick: u64) {
420        self.current_tick = tick;
421    }
422
423    /// Number of active entries
424    pub fn entry_count(&self) -> usize {
425        self.entries.len()
426    }
427
428    /// Look up an entry by key
429    pub fn lookup(&self, key: &ConntrackKey) -> Option<&ConntrackEntry> {
430        self.entries.get(key)
431    }
432
433    /// Look up an entry mutably
434    pub fn lookup_mut(&mut self, key: &ConntrackKey) -> Option<&mut ConntrackEntry> {
435        self.entries.get_mut(key)
436    }
437
438    /// Insert or update a connection tracking entry
439    ///
440    /// Returns the classified state for this packet.
441    pub fn track_packet(&mut self, key: ConntrackKey, bytes: u64, tcp_flags: u8) -> ConntrackState {
442        let current_tick = self.current_tick;
443
444        // Check forward direction
445        if let Some(entry) = self.entries.get_mut(&key) {
446            entry.update(current_tick, bytes);
447            if key.protocol == ConntrackKey::PROTO_TCP {
448                update_tcp_state(entry, tcp_flags, false);
449            }
450            return entry.state;
451        }
452
453        // Check reverse direction (reply packet)
454        let reverse_key = key.reverse();
455        if let Some(entry) = self.entries.get_mut(&reverse_key) {
456            entry.update(current_tick, bytes);
457            entry.mark_reply_seen();
458            if reverse_key.protocol == ConntrackKey::PROTO_TCP {
459                update_tcp_state(entry, tcp_flags, true);
460            }
461            return entry.state;
462        }
463
464        // New connection
465        if self.entries.len() >= self.max_entries {
466            // Table full -- run garbage collection and try again
467            self.gc();
468            if self.entries.len() >= self.max_entries {
469                return ConntrackState::Invalid;
470            }
471        }
472
473        let mut entry = ConntrackEntry::new(key, key.protocol);
474        entry.last_seen = current_tick;
475        entry.byte_count = bytes;
476        self.entries.insert(key, entry);
477        self.total_created += 1;
478        ConntrackState::New
479    }
480
481    /// Remove an entry by key
482    pub fn remove(&mut self, key: &ConntrackKey) -> Option<ConntrackEntry> {
483        self.entries.remove(key)
484    }
485
486    /// Garbage collect expired entries
487    pub fn gc(&mut self) -> usize {
488        let current_tick = self.current_tick;
489        let before = self.entries.len();
490
491        // Collect expired keys
492        let expired_keys: Vec<ConntrackKey> = self
493            .entries
494            .iter()
495            .filter(|(_, entry)| entry.is_expired(current_tick))
496            .map(|(key, _)| *key)
497            .collect();
498
499        for key in &expired_keys {
500            self.entries.remove(key);
501        }
502
503        let removed = before - self.entries.len();
504        self.total_expired += removed as u64;
505        removed
506    }
507
508    /// Classify a packet based on existing connection state
509    pub fn classify_packet(&self, key: &ConntrackKey) -> ConntrackState {
510        // Check forward direction
511        if let Some(entry) = self.entries.get(key) {
512            if entry.is_expired(self.current_tick) {
513                return ConntrackState::Invalid;
514            }
515            return entry.state;
516        }
517
518        // Check reverse direction
519        let reverse = key.reverse();
520        if let Some(entry) = self.entries.get(&reverse) {
521            if entry.is_expired(self.current_tick) {
522                return ConntrackState::Invalid;
523            }
524            if entry.reply_seen {
525                return ConntrackState::Established;
526            }
527            return ConntrackState::New;
528        }
529
530        ConntrackState::New
531    }
532
533    /// Get statistics
534    pub fn stats(&self) -> ConntrackStats {
535        ConntrackStats {
536            active_entries: self.entries.len() as u64,
537            max_entries: self.max_entries as u64,
538            total_created: self.total_created,
539            total_expired: self.total_expired,
540        }
541    }
542}
543
544impl Default for ConntrackTable {
545    fn default() -> Self {
546        Self::new()
547    }
548}
549
/// Snapshot of connection tracking counters.
#[derive(Clone, Copy, Debug, Default)]
pub struct ConntrackStats {
    /// Entries currently held in the table
    pub active_entries: u64,
    /// Configured capacity limit of the table
    pub max_entries: u64,
    /// Entries created since the table was constructed
    pub total_created: u64,
    /// Entries dropped as expired/garbage-collected
    pub total_expired: u64,
}
558
559// ============================================================================
560// Global State
561// ============================================================================
562
563static CONNTRACK_TABLE: GlobalState<spin::Mutex<ConntrackTable>> = GlobalState::new();
564
565/// Initialize the connection tracking subsystem
566pub fn init() -> Result<(), KernelError> {
567    CONNTRACK_TABLE
568        .init(spin::Mutex::new(ConntrackTable::new()))
569        .map_err(|_| KernelError::InvalidAddress { addr: 0 })?;
570    Ok(())
571}
572
573/// Access the global conntrack table
574pub fn with_conntrack<R, F: FnOnce(&mut ConntrackTable) -> R>(f: F) -> Option<R> {
575    CONNTRACK_TABLE.with(|lock| {
576        let mut table = lock.lock();
577        f(&mut table)
578    })
579}
580
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Client side of the sample connections.
    fn client_ip() -> Ipv4Address {
        Ipv4Address::new(192, 168, 1, 100)
    }

    /// Server side of the sample connections.
    fn server_ip() -> Ipv4Address {
        Ipv4Address::new(10, 0, 0, 1)
    }

    /// Client -> server HTTP connection key.
    fn tcp_key() -> ConntrackKey {
        ConntrackKey::new(client_ip(), server_ip(), 12345, 80, ConntrackKey::PROTO_TCP)
    }

    /// Client -> server DNS query key.
    fn udp_key() -> ConntrackKey {
        ConntrackKey::new(client_ip(), server_ip(), 5000, 53, ConntrackKey::PROTO_UDP)
    }

    /// Fresh TCP entry, as created for the client's SYN.
    fn syn_entry() -> ConntrackEntry {
        ConntrackEntry::new(tcp_key(), ConntrackKey::PROTO_TCP)
    }

    #[test]
    fn test_conntrack_key_reverse() {
        let key = tcp_key();
        let ConntrackKey {
            src_ip,
            dst_ip,
            src_port,
            dst_port,
            protocol,
        } = key.reverse();
        assert_eq!(src_ip, key.dst_ip);
        assert_eq!(dst_ip, key.src_ip);
        assert_eq!(src_port, key.dst_port);
        assert_eq!(dst_port, key.src_port);
        assert_eq!(protocol, key.protocol);
    }

    #[test]
    fn test_conntrack_state_default() {
        assert_eq!(ConntrackState::New, ConntrackState::default());
    }

    #[test]
    fn test_conntrack_entry_new_tcp() {
        let entry = syn_entry();
        assert_eq!(entry.state, ConntrackState::New);
        assert_eq!(entry.tcp_state, TcpConnState::SynSent);
        assert_eq!(entry.timeout_ticks, TCP_NEW_TIMEOUT);
        assert!(!entry.reply_seen);
    }

    #[test]
    fn test_conntrack_entry_new_udp() {
        let entry = ConntrackEntry::new(udp_key(), ConntrackKey::PROTO_UDP);
        assert_eq!(entry.state, ConntrackState::New);
        assert_eq!(entry.tcp_state, TcpConnState::None);
        assert_eq!(entry.timeout_ticks, UDP_TIMEOUT);
    }

    #[test]
    fn test_conntrack_entry_expired() {
        let mut entry = syn_entry();
        entry.last_seen = 100;
        entry.timeout_ticks = 50;
        // The deadline tick (100 + 50) itself already counts as expired.
        assert!(!entry.is_expired(149));
        assert!(entry.is_expired(150));
        assert!(entry.is_expired(151));
    }

    #[test]
    fn test_conntrack_table_track_new() {
        let mut table = ConntrackTable::new();
        assert_eq!(table.track_packet(tcp_key(), 64, TCP_SYN), ConntrackState::New);
        assert_eq!(table.entry_count(), 1);
    }

    #[test]
    fn test_conntrack_table_track_reply() {
        let mut table = ConntrackTable::new();
        let key = tcp_key();

        table.track_packet(key, 64, TCP_SYN);
        let reply_state = table.track_packet(key.reverse(), 64, TCP_SYN | TCP_ACK);

        assert_eq!(reply_state, ConntrackState::Established);
        // The reply matched the existing entry instead of creating another one.
        assert_eq!(table.entry_count(), 1);
    }

    #[test]
    fn test_conntrack_table_max_entries() {
        let key = |src_host, dst_host, src_port| {
            ConntrackKey::new(
                Ipv4Address::new(10, 0, 0, src_host),
                Ipv4Address::new(10, 0, 0, dst_host),
                src_port,
                80,
                ConntrackKey::PROTO_TCP,
            )
        };

        let mut table = ConntrackTable::with_max_entries(2);
        table.track_packet(key(1, 2, 1000), 64, TCP_SYN);
        table.track_packet(key(3, 4, 1001), 64, TCP_SYN);

        // Full table with nothing expired: the third connection is refused.
        let state = table.track_packet(key(5, 6, 1002), 64, TCP_SYN);
        assert_eq!(state, ConntrackState::Invalid);
        assert_eq!(table.entry_count(), 2);
    }

    #[test]
    fn test_conntrack_table_gc() {
        let mut table = ConntrackTable::new();
        table.track_packet(tcp_key(), 64, TCP_SYN);
        assert_eq!(table.entry_count(), 1);

        // Jump past the half-open timeout so the entry becomes collectable.
        table.set_tick(TCP_NEW_TIMEOUT + 1);
        assert_eq!(table.gc(), 1);
        assert_eq!(table.entry_count(), 0);
    }

    #[test]
    fn test_conntrack_classify_unknown() {
        let table = ConntrackTable::new();
        assert_eq!(table.classify_packet(&tcp_key()), ConntrackState::New);
    }

    #[test]
    fn test_tcp_state_full_handshake() {
        let mut entry = syn_entry();
        assert_eq!(entry.tcp_state, TcpConnState::SynSent);

        // Server SYN-ACK
        update_tcp_state(&mut entry, TCP_SYN | TCP_ACK, true);
        assert_eq!(entry.tcp_state, TcpConnState::SynRecv);

        // Client ACK
        update_tcp_state(&mut entry, TCP_ACK, false);
        assert_eq!(entry.tcp_state, TcpConnState::Established);
        assert_eq!(entry.state, ConntrackState::Established);
    }

    #[test]
    fn test_tcp_state_rst() {
        let mut entry = syn_entry();
        update_tcp_state(&mut entry, TCP_RST, false);
        assert_eq!(entry.tcp_state, TcpConnState::Closed);
        assert_eq!(entry.state, ConntrackState::Invalid);
    }

    #[test]
    fn test_tcp_state_fin_close() {
        let mut entry = syn_entry();

        // Handshake: server SYN-ACK, then client ACK
        update_tcp_state(&mut entry, TCP_SYN | TCP_ACK, true);
        update_tcp_state(&mut entry, TCP_ACK, false);
        assert_eq!(entry.tcp_state, TcpConnState::Established);

        // Client FIN
        update_tcp_state(&mut entry, TCP_FIN, false);
        assert_eq!(entry.tcp_state, TcpConnState::FinWait1);

        // Server FIN+ACK
        update_tcp_state(&mut entry, TCP_FIN | TCP_ACK, true);
        assert_eq!(entry.tcp_state, TcpConnState::TimeWait);
        assert_eq!(entry.state, ConntrackState::TimeWait);
    }

    #[test]
    fn test_conntrack_stats() {
        let mut table = ConntrackTable::new();
        table.track_packet(tcp_key(), 64, TCP_SYN);

        let ConntrackStats {
            active_entries,
            max_entries,
            total_created,
            ..
        } = table.stats();
        assert_eq!(active_entries, 1);
        assert_eq!(total_created, 1);
        assert_eq!(max_entries, MAX_CONNTRACK_ENTRIES as u64);
    }
}