⚠️ VeridianOS Kernel Documentation — This is low-level `no_std` kernel code. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/net/firewall/
rules.rs

1//! Firewall rule matching and evaluation
2//!
3//! Provides rule definitions with match criteria (source/dest IP with CIDR,
4//! port ranges, protocol, TCP flags, connection state) and actions
5//! (Accept, Drop, Reject, Log, Jump, Masquerade, SNAT, DNAT).
6//! CIDR matching uses bitmask comparison for efficient subnet checks.
7
8#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::collections::BTreeMap;
12#[cfg(feature = "alloc")]
13use alloc::string::String;
14#[cfg(feature = "alloc")]
15use alloc::vec::Vec;
16
17use super::conntrack::ConntrackState;
18use crate::net::{Ipv4Address, Port};
19
20// ============================================================================
21// Protocol
22// ============================================================================
23
/// Layer-4 protocol selector shared by firewall rules and packet metadata.
///
/// `Any` is the default; in a rule it acts as a wildcard that accepts every
/// packet protocol.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Protocol {
    /// Wildcard: a rule carrying this value accepts any packet protocol.
    #[default]
    Any,
    /// Transmission Control Protocol (IP protocol number 6).
    Tcp,
    /// User Datagram Protocol (IP protocol number 17).
    Udp,
    /// Internet Control Message Protocol (IP protocol number 1).
    Icmp,
    /// ICMP for IPv6 (IP protocol number 58).
    Icmpv6,
}
39
40// ============================================================================
41// TCP Flags
42// ============================================================================
43
/// TCP flag match expressed as a mask/value pair.
///
/// A packet's raw flag byte matches when `flags & mask == value`. The default
/// (`mask == 0`, `value == 0`) therefore matches every flag combination.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct TcpFlags {
    /// Flags to examine; bits outside the mask are ignored.
    pub mask: u8,
    /// Required state of the examined flags: bits set here must be set in the
    /// packet, and mask bits clear here must be clear in the packet. Bits set
    /// here but outside `mask` can never be satisfied.
    pub value: u8,
}
52
53impl TcpFlags {
54    pub const SYN: u8 = 0x02;
55    pub const ACK: u8 = 0x10;
56    pub const FIN: u8 = 0x01;
57    pub const RST: u8 = 0x04;
58    pub const PSH: u8 = 0x08;
59    pub const URG: u8 = 0x20;
60
61    /// Create a new TCP flags match
62    pub const fn new(mask: u8, value: u8) -> Self {
63        Self { mask, value }
64    }
65
66    /// Match SYN-only packets (SYN set, ACK cleared)
67    pub const fn syn_only() -> Self {
68        Self {
69            mask: Self::SYN | Self::ACK,
70            value: Self::SYN,
71        }
72    }
73
74    /// Match established connection packets (ACK set)
75    pub const fn established() -> Self {
76        Self {
77            mask: Self::ACK,
78            value: Self::ACK,
79        }
80    }
81
82    /// Check if the given flags match this criteria
83    pub fn matches(&self, flags: u8) -> bool {
84        (flags & self.mask) == self.value
85    }
86}
87
88// ============================================================================
89// Port Range
90// ============================================================================
91
92/// A range of ports for matching (inclusive on both ends)
93#[derive(Debug, Clone, Copy, PartialEq, Eq)]
94pub struct PortRange {
95    pub start: Port,
96    pub end: Port,
97}
98
99impl PortRange {
100    /// Create a range matching a single port
101    pub const fn single(port: Port) -> Self {
102        Self {
103            start: port,
104            end: port,
105        }
106    }
107
108    /// Create a range of ports (inclusive)
109    pub const fn range(start: Port, end: Port) -> Self {
110        Self { start, end }
111    }
112
113    /// Create a range matching any port
114    pub const fn any() -> Self {
115        Self {
116            start: 0,
117            end: 65535,
118        }
119    }
120
121    /// Check if a port falls within this range
122    pub fn contains(&self, port: Port) -> bool {
123        port >= self.start && port <= self.end
124    }
125}
126
127impl Default for PortRange {
128    fn default() -> Self {
129        Self::any()
130    }
131}
132
133// ============================================================================
134// CIDR Address
135// ============================================================================
136
137/// IPv4 address with CIDR prefix length for subnet matching
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct CidrAddress {
140    /// Base address
141    pub address: Ipv4Address,
142    /// Prefix length (0-32)
143    pub prefix_len: u8,
144}
145
146impl CidrAddress {
147    /// Create a CIDR address
148    pub const fn new(address: Ipv4Address, prefix_len: u8) -> Self {
149        Self {
150            address,
151            prefix_len,
152        }
153    }
154
155    /// Create a CIDR that matches any address (0.0.0.0/0)
156    pub const fn any() -> Self {
157        Self {
158            address: Ipv4Address::ANY,
159            prefix_len: 0,
160        }
161    }
162
163    /// Create a CIDR that matches a single host (/32)
164    pub const fn host(address: Ipv4Address) -> Self {
165        Self {
166            address,
167            prefix_len: 32,
168        }
169    }
170
171    /// Compute the subnet mask as a u32
172    fn mask(&self) -> u32 {
173        if self.prefix_len == 0 {
174            0
175        } else if self.prefix_len >= 32 {
176            0xFFFF_FFFF
177        } else {
178            0xFFFF_FFFF << (32 - self.prefix_len)
179        }
180    }
181
182    /// Check if an address matches this CIDR block
183    pub fn matches(&self, addr: &Ipv4Address) -> bool {
184        let mask = self.mask();
185        (self.address.to_u32() & mask) == (addr.to_u32() & mask)
186    }
187}
188
189impl Default for CidrAddress {
190    fn default() -> Self {
191        Self::any()
192    }
193}
194
195// ============================================================================
196// Match Criteria
197// ============================================================================
198
199/// Criteria for matching packets against a firewall rule
200#[derive(Debug, Clone, Default)]
201pub struct MatchCriteria {
202    /// Source IP with CIDR mask (None = match any)
203    pub src_ip: Option<CidrAddress>,
204    /// Destination IP with CIDR mask (None = match any)
205    pub dst_ip: Option<CidrAddress>,
206    /// Source port range (None = match any)
207    pub src_port: Option<PortRange>,
208    /// Destination port range (None = match any)
209    pub dst_port: Option<PortRange>,
210    /// IP protocol (Any = match all protocols)
211    pub protocol: Protocol,
212    /// TCP flags to match (None = don't check flags)
213    pub tcp_flags: Option<TcpFlags>,
214    /// Connection tracking state (None = don't check state)
215    pub conn_state: Option<ConntrackState>,
216    /// Negate source IP match
217    pub negate_src: bool,
218    /// Negate destination IP match
219    pub negate_dst: bool,
220}
221
222impl MatchCriteria {
223    /// Create criteria matching everything
224    pub fn new() -> Self {
225        Self::default()
226    }
227
228    /// Set source IP CIDR
229    pub fn with_src_ip(mut self, cidr: CidrAddress) -> Self {
230        self.src_ip = Some(cidr);
231        self
232    }
233
234    /// Set destination IP CIDR
235    pub fn with_dst_ip(mut self, cidr: CidrAddress) -> Self {
236        self.dst_ip = Some(cidr);
237        self
238    }
239
240    /// Set source port range
241    pub fn with_src_port(mut self, range: PortRange) -> Self {
242        self.src_port = Some(range);
243        self
244    }
245
246    /// Set destination port range
247    pub fn with_dst_port(mut self, range: PortRange) -> Self {
248        self.dst_port = Some(range);
249        self
250    }
251
252    /// Set protocol
253    pub fn with_protocol(mut self, proto: Protocol) -> Self {
254        self.protocol = proto;
255        self
256    }
257
258    /// Set TCP flags match
259    pub fn with_tcp_flags(mut self, flags: TcpFlags) -> Self {
260        self.tcp_flags = Some(flags);
261        self
262    }
263
264    /// Set connection state match
265    pub fn with_conn_state(mut self, state: ConntrackState) -> Self {
266        self.conn_state = Some(state);
267        self
268    }
269}
270
271// ============================================================================
272// Rule Actions
273// ============================================================================
274
275/// Action to take when a rule matches
276#[derive(Debug, Clone, PartialEq, Eq, Default)]
277pub enum RuleAction {
278    /// Allow the packet
279    #[default]
280    Accept,
281    /// Silently drop the packet
282    Drop,
283    /// Drop and send ICMP unreachable
284    Reject,
285    /// Log the packet and continue evaluation
286    Log,
287    /// Jump to another chain
288    Jump(String),
289    /// Return from current chain to caller
290    Return,
291    /// Source NAT with masquerading (use outgoing interface address)
292    Masquerade,
293    /// Source NAT to a specific address
294    Snat(Ipv4Address),
295    /// Destination NAT to a specific address and port
296    Dnat(Ipv4Address, Port),
297}
298
299// ============================================================================
300// Packet Metadata
301// ============================================================================
302
303/// Extracted packet metadata used for rule evaluation
304///
305/// This avoids passing raw packet bytes -- the caller extracts relevant
306/// header fields before calling the firewall engine.
307#[derive(Debug, Clone)]
308pub struct PacketMetadata {
309    /// Source IPv4 address
310    pub src_ip: Ipv4Address,
311    /// Destination IPv4 address
312    pub dst_ip: Ipv4Address,
313    /// Source port (0 for ICMP)
314    pub src_port: Port,
315    /// Destination port (0 for ICMP)
316    pub dst_port: Port,
317    /// IP protocol
318    pub protocol: Protocol,
319    /// TCP flags (raw byte)
320    pub tcp_flags: u8,
321    /// Connection tracking state (if known)
322    pub conn_state: Option<ConntrackState>,
323    /// Total packet length in bytes
324    pub packet_len: u16,
325}
326
327impl Default for PacketMetadata {
328    fn default() -> Self {
329        Self {
330            src_ip: Ipv4Address::ANY,
331            dst_ip: Ipv4Address::ANY,
332            src_port: 0,
333            dst_port: 0,
334            protocol: Protocol::default(),
335            tcp_flags: 0,
336            conn_state: None,
337            packet_len: 0,
338        }
339    }
340}
341
342// ============================================================================
343// Firewall Rule
344// ============================================================================
345
346/// A single firewall rule with match criteria, action, and counters
347#[derive(Debug, Clone)]
348pub struct FirewallRule {
349    /// Unique rule identifier
350    pub id: u64,
351    /// Priority (lower = evaluated first within a chain)
352    pub priority: u32,
353    /// Match criteria
354    pub criteria: MatchCriteria,
355    /// Action to take on match
356    pub action: RuleAction,
357    /// Packet counter
358    pub packets: u64,
359    /// Byte counter
360    pub bytes: u64,
361    /// Whether this rule is active
362    pub enabled: bool,
363    /// Optional comment/description
364    pub comment: String,
365}
366
367impl FirewallRule {
368    /// Create a new rule with the given criteria and action
369    pub fn new(id: u64, criteria: MatchCriteria, action: RuleAction) -> Self {
370        Self {
371            id,
372            priority: 0,
373            criteria,
374            action,
375            packets: 0,
376            bytes: 0,
377            enabled: true,
378            comment: String::new(),
379        }
380    }
381
382    /// Set the rule priority
383    pub fn with_priority(mut self, priority: u32) -> Self {
384        self.priority = priority;
385        self
386    }
387
388    /// Set the rule comment
389    pub fn with_comment(mut self, comment: &str) -> Self {
390        self.comment = String::from(comment);
391        self
392    }
393
394    /// Check if this rule matches the given packet metadata
395    pub fn matches_packet(&self, meta: &PacketMetadata) -> bool {
396        // Protocol check
397        if self.criteria.protocol != Protocol::Any && self.criteria.protocol != meta.protocol {
398            return false;
399        }
400
401        // Source IP check
402        if let Some(ref cidr) = self.criteria.src_ip {
403            let matches = cidr.matches(&meta.src_ip);
404            if matches == self.criteria.negate_src {
405                return false;
406            }
407        }
408
409        // Destination IP check
410        if let Some(ref cidr) = self.criteria.dst_ip {
411            let matches = cidr.matches(&meta.dst_ip);
412            if matches == self.criteria.negate_dst {
413                return false;
414            }
415        }
416
417        // Source port check
418        if let Some(ref range) = self.criteria.src_port {
419            if !range.contains(meta.src_port) {
420                return false;
421            }
422        }
423
424        // Destination port check
425        if let Some(ref range) = self.criteria.dst_port {
426            if !range.contains(meta.dst_port) {
427                return false;
428            }
429        }
430
431        // TCP flags check
432        if let Some(ref flags) = self.criteria.tcp_flags {
433            if !flags.matches(meta.tcp_flags) {
434                return false;
435            }
436        }
437
438        // Connection state check
439        if let Some(ref expected_state) = self.criteria.conn_state {
440            match meta.conn_state {
441                Some(ref actual_state) if actual_state == expected_state => {}
442                _ => return false,
443            }
444        }
445
446        true
447    }
448
449    /// Reset packet/byte counters
450    pub fn reset_counters(&mut self) {
451        self.packets = 0;
452        self.bytes = 0;
453    }
454}
455
456// ============================================================================
457// Rule Engine
458// ============================================================================
459
460/// Manages all firewall rules and provides lookup by ID
461pub struct RuleEngine {
462    /// All rules indexed by ID
463    rules: BTreeMap<u64, FirewallRule>,
464    /// Next available rule ID
465    next_id: u64,
466}
467
468impl RuleEngine {
469    /// Create a new empty rule engine
470    pub fn new() -> Self {
471        Self {
472            rules: BTreeMap::new(),
473            next_id: 1,
474        }
475    }
476
477    /// Add a rule and return its assigned ID
478    pub fn add_rule(&mut self, mut rule: FirewallRule) -> u64 {
479        let id = self.next_id;
480        self.next_id += 1;
481        rule.id = id;
482        self.rules.insert(id, rule);
483        id
484    }
485
486    /// Remove a rule by ID
487    pub fn remove_rule(&mut self, id: u64) -> Option<FirewallRule> {
488        self.rules.remove(&id)
489    }
490
491    /// Get a rule by ID (immutable)
492    pub fn get_rule(&self, id: u64) -> Option<&FirewallRule> {
493        self.rules.get(&id)
494    }
495
496    /// Get a rule by ID (mutable)
497    pub fn get_rule_mut(&mut self, id: u64) -> Option<&mut FirewallRule> {
498        self.rules.get_mut(&id)
499    }
500
501    /// Number of rules
502    pub fn rule_count(&self) -> usize {
503        self.rules.len()
504    }
505
506    /// Evaluate a packet against a list of rule IDs, returning the first
507    /// matching rule's action
508    pub fn evaluate(&mut self, rule_ids: &[u64], metadata: &PacketMetadata) -> Option<RuleAction> {
509        for &id in rule_ids {
510            if let Some(rule) = self.rules.get_mut(&id) {
511                if rule.enabled && rule.matches_packet(metadata) {
512                    rule.packets += 1;
513                    rule.bytes += metadata.packet_len as u64;
514                    return Some(rule.action.clone());
515                }
516            }
517        }
518        None
519    }
520
521    /// Get all rules sorted by priority
522    pub fn rules_by_priority(&self) -> Vec<&FirewallRule> {
523        let mut rules: Vec<&FirewallRule> = self.rules.values().collect();
524        rules.sort_by_key(|r| r.priority);
525        rules
526    }
527}
528
529impl Default for RuleEngine {
530    fn default() -> Self {
531        Self::new()
532    }
533}
534
535// ============================================================================
536// Tests
537// ============================================================================
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline packet: TCP SYN from 192.168.1.100:12345 to 10.0.0.1:80,
    /// conntrack state `New`, 64 bytes long.
    fn test_metadata() -> PacketMetadata {
        PacketMetadata {
            src_ip: Ipv4Address::new(192, 168, 1, 100),
            dst_ip: Ipv4Address::new(10, 0, 0, 1),
            src_port: 12345,
            dst_port: 80,
            protocol: Protocol::Tcp,
            tcp_flags: TcpFlags::SYN,
            conn_state: Some(ConntrackState::New),
            packet_len: 64,
        }
    }

    /// A rule with no conditions and the given action.
    fn catch_all(action: RuleAction) -> FirewallRule {
        FirewallRule::new(0, MatchCriteria::new(), action)
    }

    #[test]
    fn test_protocol_default() {
        assert_eq!(Protocol::default(), Protocol::Any);
    }

    #[test]
    fn test_port_range_single() {
        let range = PortRange::single(80);
        assert!(range.contains(80));
        assert!(!range.contains(81));
        assert!(!range.contains(79));
    }

    #[test]
    fn test_port_range_range() {
        let range = PortRange::range(1024, 65535);
        assert!(!range.contains(80));
        assert!(range.contains(1024));
        assert!(range.contains(50000));
        assert!(range.contains(65535));
    }

    #[test]
    fn test_port_range_any() {
        let range = PortRange::any();
        assert!(range.contains(0));
        assert!(range.contains(80));
        assert!(range.contains(65535));
    }

    #[test]
    fn test_port_range_reversed_is_empty() {
        let range = PortRange::range(100, 10);
        assert!(!range.contains(10));
        assert!(!range.contains(50));
        assert!(!range.contains(100));
    }

    #[test]
    fn test_cidr_matches_slash32() {
        let cidr = CidrAddress::host(Ipv4Address::new(192, 168, 1, 1));
        assert!(cidr.matches(&Ipv4Address::new(192, 168, 1, 1)));
        assert!(!cidr.matches(&Ipv4Address::new(192, 168, 1, 2)));
    }

    #[test]
    fn test_cidr_matches_slash24() {
        let cidr = CidrAddress::new(Ipv4Address::new(192, 168, 1, 0), 24);
        assert!(cidr.matches(&Ipv4Address::new(192, 168, 1, 0)));
        assert!(cidr.matches(&Ipv4Address::new(192, 168, 1, 255)));
        assert!(!cidr.matches(&Ipv4Address::new(192, 168, 2, 1)));
    }

    #[test]
    fn test_cidr_matches_slash0() {
        let cidr = CidrAddress::any();
        assert!(cidr.matches(&Ipv4Address::new(1, 2, 3, 4)));
        assert!(cidr.matches(&Ipv4Address::new(255, 255, 255, 255)));
    }

    #[test]
    fn test_cidr_matches_slash16() {
        let cidr = CidrAddress::new(Ipv4Address::new(10, 0, 0, 0), 16);
        assert!(cidr.matches(&Ipv4Address::new(10, 0, 0, 1)));
        assert!(cidr.matches(&Ipv4Address::new(10, 0, 255, 255)));
        assert!(!cidr.matches(&Ipv4Address::new(10, 1, 0, 1)));
    }

    #[test]
    fn test_cidr_prefix_above_32_acts_as_host() {
        let cidr = CidrAddress::new(Ipv4Address::new(192, 168, 1, 1), 40);
        assert!(cidr.matches(&Ipv4Address::new(192, 168, 1, 1)));
        assert!(!cidr.matches(&Ipv4Address::new(192, 168, 1, 2)));
    }

    #[test]
    fn test_tcp_flags_syn_only() {
        let flags = TcpFlags::syn_only();
        assert!(flags.matches(TcpFlags::SYN));
        assert!(!flags.matches(TcpFlags::SYN | TcpFlags::ACK));
        assert!(!flags.matches(TcpFlags::ACK));
    }

    #[test]
    fn test_tcp_flags_established() {
        let flags = TcpFlags::established();
        assert!(flags.matches(TcpFlags::ACK));
        assert!(flags.matches(TcpFlags::SYN | TcpFlags::ACK));
        assert!(!flags.matches(TcpFlags::SYN));
    }

    #[test]
    fn test_tcp_flags_custom_mask() {
        let syn_fin = TcpFlags::SYN | TcpFlags::FIN;
        let flags = TcpFlags::new(syn_fin, syn_fin);
        assert!(flags.matches(syn_fin));
        assert!(flags.matches(syn_fin | TcpFlags::ACK));
        assert!(!flags.matches(TcpFlags::SYN));
        assert!(!flags.matches(TcpFlags::FIN));
    }

    #[test]
    fn test_rule_matches_all() {
        let rule = FirewallRule::new(1, MatchCriteria::new(), RuleAction::Accept);
        let meta = test_metadata();
        assert!(rule.matches_packet(&meta));
    }

    #[test]
    fn test_rule_matches_src_ip() {
        let criteria = MatchCriteria::new()
            .with_src_ip(CidrAddress::new(Ipv4Address::new(192, 168, 1, 0), 24));
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        let meta = test_metadata();
        assert!(rule.matches_packet(&meta));

        let mut meta2 = test_metadata();
        meta2.src_ip = Ipv4Address::new(10, 0, 0, 5);
        assert!(!rule.matches_packet(&meta2));
    }

    #[test]
    fn test_rule_negated_src_ip() {
        let mut criteria = MatchCriteria::new()
            .with_src_ip(CidrAddress::new(Ipv4Address::new(192, 168, 1, 0), 24));
        criteria.negate_src = true;
        let rule = FirewallRule::new(1, criteria, RuleAction::Drop);

        // 192.168.1.100 is inside the block, so the negated test fails.
        assert!(!rule.matches_packet(&test_metadata()));

        let mut outside = test_metadata();
        outside.src_ip = Ipv4Address::new(10, 0, 0, 5);
        assert!(rule.matches_packet(&outside));
    }

    #[test]
    fn test_rule_negated_dst_ip() {
        let mut criteria = MatchCriteria::new()
            .with_dst_ip(CidrAddress::new(Ipv4Address::new(10, 0, 0, 0), 8));
        criteria.negate_dst = true;
        let rule = FirewallRule::new(1, criteria, RuleAction::Drop);

        assert!(!rule.matches_packet(&test_metadata()));

        let mut outside = test_metadata();
        outside.dst_ip = Ipv4Address::new(172, 16, 0, 1);
        assert!(rule.matches_packet(&outside));
    }

    #[test]
    fn test_rule_negate_without_cidr_is_ignored() {
        let mut criteria = MatchCriteria::new();
        criteria.negate_src = true;
        criteria.negate_dst = true;
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        assert!(rule.matches_packet(&test_metadata()));
    }

    #[test]
    fn test_rule_matches_dst_port() {
        let criteria = MatchCriteria::new()
            .with_protocol(Protocol::Tcp)
            .with_dst_port(PortRange::single(80));
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        let meta = test_metadata();
        assert!(rule.matches_packet(&meta));

        let mut meta2 = test_metadata();
        meta2.dst_port = 443;
        assert!(!rule.matches_packet(&meta2));
    }

    #[test]
    fn test_rule_matches_protocol() {
        let criteria = MatchCriteria::new().with_protocol(Protocol::Udp);
        let rule = FirewallRule::new(1, criteria, RuleAction::Drop);
        let meta = test_metadata(); // TCP
        assert!(!rule.matches_packet(&meta));
    }

    #[test]
    fn test_rule_matches_tcp_flags() {
        let criteria = MatchCriteria::new()
            .with_protocol(Protocol::Tcp)
            .with_tcp_flags(TcpFlags::syn_only());
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        assert!(rule.matches_packet(&test_metadata()));

        let mut syn_ack = test_metadata();
        syn_ack.tcp_flags = TcpFlags::SYN | TcpFlags::ACK;
        assert!(!rule.matches_packet(&syn_ack));
    }

    #[test]
    fn test_rule_matches_conn_state() {
        let criteria = MatchCriteria::new().with_conn_state(ConntrackState::Established);
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        let meta = test_metadata(); // New
        assert!(!rule.matches_packet(&meta));

        let mut meta2 = test_metadata();
        meta2.conn_state = Some(ConntrackState::Established);
        assert!(rule.matches_packet(&meta2));
    }

    #[test]
    fn test_rule_conn_state_unknown_does_not_match() {
        let criteria = MatchCriteria::new().with_conn_state(ConntrackState::New);
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        let mut meta = test_metadata();
        meta.conn_state = None;
        assert!(!rule.matches_packet(&meta));
    }

    #[test]
    fn test_rule_engine_add_evaluate() {
        let mut engine = RuleEngine::new();
        let criteria = MatchCriteria::new()
            .with_protocol(Protocol::Tcp)
            .with_dst_port(PortRange::single(80));
        let rule = FirewallRule::new(0, criteria, RuleAction::Accept);
        let id = engine.add_rule(rule);
        assert_eq!(id, 1);
        assert_eq!(engine.rule_count(), 1);

        let meta = test_metadata();
        let action = engine.evaluate(&[id], &meta);
        assert_eq!(action, Some(RuleAction::Accept));

        // Check counters
        let r = engine.get_rule(id).unwrap();
        assert_eq!(r.packets, 1);
        assert_eq!(r.bytes, 64);
    }

    #[test]
    fn test_rule_engine_no_match() {
        let mut engine = RuleEngine::new();
        let criteria = MatchCriteria::new().with_protocol(Protocol::Udp);
        let rule = FirewallRule::new(0, criteria, RuleAction::Drop);
        let id = engine.add_rule(rule);

        let meta = test_metadata(); // TCP
        let action = engine.evaluate(&[id], &meta);
        assert_eq!(action, None);
    }

    #[test]
    fn test_rule_engine_empty_id_list() {
        let mut engine = RuleEngine::new();
        engine.add_rule(catch_all(RuleAction::Accept));
        assert_eq!(engine.evaluate(&[], &test_metadata()), None);
    }

    #[test]
    fn test_rule_engine_first_match_wins() {
        let mut engine = RuleEngine::new();
        let https = engine.add_rule(FirewallRule::new(
            0,
            MatchCriteria::new().with_dst_port(PortRange::single(443)),
            RuleAction::Drop,
        ));
        let accept_all = engine.add_rule(catch_all(RuleAction::Accept));
        let drop_all = engine.add_rule(catch_all(RuleAction::Drop));

        let action = engine.evaluate(&[https, accept_all, drop_all], &test_metadata());
        assert_eq!(action, Some(RuleAction::Accept));

        // Only the winning rule's counters move.
        assert_eq!(engine.get_rule(https).unwrap().packets, 0);
        assert_eq!(engine.get_rule(accept_all).unwrap().packets, 1);
        assert_eq!(engine.get_rule(drop_all).unwrap().packets, 0);
    }

    #[test]
    fn test_rule_engine_skips_unknown_and_disabled() {
        let mut engine = RuleEngine::new();
        let mut disabled = catch_all(RuleAction::Drop);
        disabled.enabled = false;
        let disabled_id = engine.add_rule(disabled);
        let accept_id = engine.add_rule(catch_all(RuleAction::Accept));

        let action = engine.evaluate(&[999, disabled_id, accept_id], &test_metadata());
        assert_eq!(action, Some(RuleAction::Accept));
        assert_eq!(engine.get_rule(disabled_id).unwrap().packets, 0);
    }

    #[test]
    fn test_rule_engine_remove() {
        let mut engine = RuleEngine::new();
        let rule = FirewallRule::new(0, MatchCriteria::new(), RuleAction::Drop);
        let id = engine.add_rule(rule);
        assert_eq!(engine.rule_count(), 1);

        engine.remove_rule(id);
        assert_eq!(engine.rule_count(), 0);
    }

    #[test]
    fn test_add_rule_assigns_sequential_ids() {
        let mut engine = RuleEngine::new();
        let a = engine.add_rule(FirewallRule::new(42, MatchCriteria::new(), RuleAction::Accept));
        let b = engine.add_rule(FirewallRule::new(42, MatchCriteria::new(), RuleAction::Accept));
        assert_eq!((a, b), (1, 2));
        assert_eq!(engine.get_rule(b).unwrap().id, 2);

        // IDs are not reused after removal.
        engine.remove_rule(b);
        let c = engine.add_rule(catch_all(RuleAction::Accept));
        assert_eq!(c, 3);
    }

    #[test]
    fn test_rules_by_priority_is_stable() {
        let mut engine = RuleEngine::new();
        let late = engine.add_rule(catch_all(RuleAction::Accept).with_priority(20));
        let first = engine.add_rule(catch_all(RuleAction::Accept).with_priority(10));
        let second = engine.add_rule(catch_all(RuleAction::Drop).with_priority(10));

        let ordered = engine.rules_by_priority();
        assert_eq!(ordered.len(), 3);
        assert_eq!(ordered[0].id, first);
        assert_eq!(ordered[1].id, second);
        assert_eq!(ordered[2].id, late);
    }

    #[test]
    fn test_reset_counters() {
        let mut engine = RuleEngine::new();
        let id = engine.add_rule(catch_all(RuleAction::Accept));
        let meta = test_metadata();
        engine.evaluate(&[id], &meta);
        engine.evaluate(&[id], &meta);

        let rule = engine.get_rule_mut(id).unwrap();
        assert_eq!((rule.packets, rule.bytes), (2, 128));
        rule.reset_counters();
        assert_eq!((rule.packets, rule.bytes), (0, 0));
    }

    #[test]
    fn test_rule_disabled() {
        let mut engine = RuleEngine::new();
        let mut rule = FirewallRule::new(0, MatchCriteria::new(), RuleAction::Drop);
        rule.enabled = false;
        let id = engine.add_rule(rule);

        let meta = test_metadata();
        let action = engine.evaluate(&[id], &meta);
        assert_eq!(action, None);
    }
}