⚠️ VeridianOS Kernel Documentation — This is low-level `no_std` kernel code. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/net/firewall/chain.rs

1//! Firewall chain and table management
2//!
3//! Implements the netfilter-style chain architecture with five hook points
4//! (PreRouting, Input, Forward, Output, PostRouting), three table types
5//! (Filter, Nat, Mangle), and a packet processing engine that evaluates
6//! rules in priority order and returns a verdict.
7
8#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::string::String;
12#[cfg(feature = "alloc")]
13use alloc::vec::Vec;
14
15use super::rules::{FirewallRule, PacketMetadata, RuleAction, RuleEngine};
16use crate::{error::KernelError, sync::once_lock::GlobalState};
17
18// ============================================================================
19// Hook Points & Enums
20// ============================================================================
21
/// Points in the packet path at which firewall chains are traversed,
/// modelled on the netfilter hooks.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum HookPoint {
    /// Inbound packet, before the routing decision has been made
    PreRouting,
    /// Packet addressed to this host
    Input,
    /// Packet being routed through this host to another interface
    Forward,
    /// Packet originating on this host
    Output,
    /// Outbound packet, after the routing decision has been made
    PostRouting,
}
36
/// Fallback action a chain takes when none of its rules decides the packet.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ChainPolicy {
    /// Let the packet continue (the default)
    #[default]
    Accept,
    /// Discard the packet silently
    Drop,
}
46
/// Kind of firewall table; determines what its chains are used for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChainType {
    /// Accept / drop / reject filtering
    Filter,
    /// Network address translation
    Nat,
    /// Modification of packet headers
    Mangle,
}
57
/// Final decision produced by running a packet through the firewall.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Verdict {
    /// Let the packet through (the default)
    #[default]
    Accept,
    /// Discard the packet silently
    Drop,
    /// Discard the packet and answer with ICMP unreachable
    Reject,
    /// Hand the packet to a userspace queue for inspection
    Queue,
}
71
72// ============================================================================
73// Chain
74// ============================================================================
75
76/// A chain of firewall rules evaluated in order
77#[derive(Debug, Clone)]
78pub struct Chain {
79    /// Name of this chain (e.g., "INPUT", "FORWARD", "custom_chain")
80    pub name: String,
81    /// Hook point this chain is attached to
82    pub hook_point: HookPoint,
83    /// Default policy when no rules match
84    pub policy: ChainPolicy,
85    /// Rule IDs in evaluation order (lower index = higher priority)
86    pub rule_ids: Vec<u64>,
87}
88
89impl Chain {
90    /// Create a new chain with the given name and hook point
91    pub fn new(name: &str, hook_point: HookPoint, policy: ChainPolicy) -> Self {
92        Self {
93            name: String::from(name),
94            hook_point,
95            policy,
96            rule_ids: Vec::new(),
97        }
98    }
99
100    /// Append a rule ID to this chain
101    pub fn add_rule(&mut self, rule_id: u64) {
102        self.rule_ids.push(rule_id);
103    }
104
105    /// Remove a rule ID from this chain
106    pub fn remove_rule(&mut self, rule_id: u64) -> bool {
107        if let Some(pos) = self.rule_ids.iter().position(|&id| id == rule_id) {
108            self.rule_ids.remove(pos);
109            true
110        } else {
111            false
112        }
113    }
114
115    /// Insert a rule ID at the given position
116    pub fn insert_rule(&mut self, index: usize, rule_id: u64) {
117        let idx = index.min(self.rule_ids.len());
118        self.rule_ids.insert(idx, rule_id);
119    }
120
121    /// Number of rules in this chain
122    pub fn rule_count(&self) -> usize {
123        self.rule_ids.len()
124    }
125}
126
127// ============================================================================
128// Firewall Table
129// ============================================================================
130
131/// A firewall table containing chains for a specific purpose
132#[derive(Debug, Clone)]
133pub struct FirewallTable {
134    /// Table type
135    pub table_type: ChainType,
136    /// Chains belonging to this table
137    pub chains: Vec<Chain>,
138}
139
140impl FirewallTable {
141    /// Create a new filter table with default INPUT, FORWARD, OUTPUT chains
142    pub fn new_filter() -> Self {
143        Self {
144            table_type: ChainType::Filter,
145            chains: alloc::vec![
146                Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept),
147                Chain::new("FORWARD", HookPoint::Forward, ChainPolicy::Drop),
148                Chain::new("OUTPUT", HookPoint::Output, ChainPolicy::Accept),
149            ],
150        }
151    }
152
153    /// Create a new NAT table with default PREROUTING, POSTROUTING chains
154    pub fn new_nat() -> Self {
155        Self {
156            table_type: ChainType::Nat,
157            chains: alloc::vec![
158                Chain::new("PREROUTING", HookPoint::PreRouting, ChainPolicy::Accept),
159                Chain::new("POSTROUTING", HookPoint::PostRouting, ChainPolicy::Accept),
160                Chain::new("OUTPUT", HookPoint::Output, ChainPolicy::Accept),
161            ],
162        }
163    }
164
165    /// Create a new mangle table with all five hook points
166    pub fn new_mangle() -> Self {
167        Self {
168            table_type: ChainType::Mangle,
169            chains: alloc::vec![
170                Chain::new("PREROUTING", HookPoint::PreRouting, ChainPolicy::Accept),
171                Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept),
172                Chain::new("FORWARD", HookPoint::Forward, ChainPolicy::Accept),
173                Chain::new("OUTPUT", HookPoint::Output, ChainPolicy::Accept),
174                Chain::new("POSTROUTING", HookPoint::PostRouting, ChainPolicy::Accept),
175            ],
176        }
177    }
178
179    /// Find a chain by name
180    pub fn get_chain(&self, name: &str) -> Option<&Chain> {
181        self.chains.iter().find(|c| c.name == name)
182    }
183
184    /// Find a mutable chain by name
185    pub fn get_chain_mut(&mut self, name: &str) -> Option<&mut Chain> {
186        self.chains.iter_mut().find(|c| c.name == name)
187    }
188
189    /// Get all chains for a specific hook point
190    pub fn chains_for_hook(&self, hook: HookPoint) -> Vec<&Chain> {
191        self.chains
192            .iter()
193            .filter(|c| c.hook_point == hook)
194            .collect()
195    }
196}
197
198// ============================================================================
199// Firewall Engine
200// ============================================================================
201
202/// The main firewall engine that manages tables and processes packets
203pub struct FirewallEngine {
204    /// Filter table
205    pub filter: FirewallTable,
206    /// NAT table
207    pub nat: FirewallTable,
208    /// Mangle table
209    pub mangle: FirewallTable,
210    /// Rule engine for evaluating individual rules
211    pub rule_engine: RuleEngine,
212    /// Packet counter
213    pub total_packets: u64,
214    /// Dropped packet counter
215    pub dropped_packets: u64,
216}
217
218impl FirewallEngine {
219    /// Create a new firewall engine with default tables
220    pub fn new() -> Self {
221        Self {
222            filter: FirewallTable::new_filter(),
223            nat: FirewallTable::new_nat(),
224            mangle: FirewallTable::new_mangle(),
225            rule_engine: RuleEngine::new(),
226            total_packets: 0,
227            dropped_packets: 0,
228        }
229    }
230
231    /// Add a rule to the rule engine and return its ID
232    pub fn add_rule(&mut self, rule: FirewallRule) -> u64 {
233        self.rule_engine.add_rule(rule)
234    }
235
236    /// Add a rule ID to a specific chain in the filter table
237    pub fn add_to_filter_chain(&mut self, chain_name: &str, rule_id: u64) -> bool {
238        if let Some(chain) = self.filter.get_chain_mut(chain_name) {
239            chain.add_rule(rule_id);
240            true
241        } else {
242            false
243        }
244    }
245
246    /// Add a rule ID to a specific chain in the NAT table
247    pub fn add_to_nat_chain(&mut self, chain_name: &str, rule_id: u64) -> bool {
248        if let Some(chain) = self.nat.get_chain_mut(chain_name) {
249            chain.add_rule(rule_id);
250            true
251        } else {
252            false
253        }
254    }
255
256    /// Set the policy for a chain in the filter table
257    pub fn set_filter_policy(&mut self, chain_name: &str, policy: ChainPolicy) -> bool {
258        if let Some(chain) = self.filter.get_chain_mut(chain_name) {
259            chain.policy = policy;
260            true
261        } else {
262            false
263        }
264    }
265
266    /// Process a packet through chains at a specific hook point
267    ///
268    /// Evaluates mangle first, then filter (or nat for PreRouting/PostRouting).
269    /// Returns the final verdict.
270    pub fn process_packet(&mut self, hook: HookPoint, metadata: &PacketMetadata) -> Verdict {
271        self.total_packets += 1;
272
273        // Phase 1: Mangle table
274        let mangle_verdict = self.evaluate_table_chains(&self.mangle, hook);
275        if mangle_verdict != Verdict::Accept {
276            if mangle_verdict == Verdict::Drop || mangle_verdict == Verdict::Reject {
277                self.dropped_packets += 1;
278            }
279            return mangle_verdict;
280        }
281
282        // Phase 2: NAT table (only for PreRouting, PostRouting, Output)
283        match hook {
284            HookPoint::PreRouting | HookPoint::PostRouting | HookPoint::Output => {
285                let nat_verdict = self.evaluate_table_chains(&self.nat, hook);
286                if nat_verdict != Verdict::Accept {
287                    if nat_verdict == Verdict::Drop || nat_verdict == Verdict::Reject {
288                        self.dropped_packets += 1;
289                    }
290                    return nat_verdict;
291                }
292            }
293            _ => {}
294        }
295
296        // Phase 3: Filter table (for Input, Forward, Output)
297        match hook {
298            HookPoint::Input | HookPoint::Forward | HookPoint::Output => {
299                let filter_verdict = self.evaluate_filter_chains(hook, metadata);
300                if filter_verdict != Verdict::Accept {
301                    if filter_verdict == Verdict::Drop || filter_verdict == Verdict::Reject {
302                        self.dropped_packets += 1;
303                    }
304                    return filter_verdict;
305                }
306            }
307            _ => {}
308        }
309
310        Verdict::Accept
311    }
312
313    /// Evaluate all chains in a table for a given hook point (no rule matching)
314    fn evaluate_table_chains(&self, table: &FirewallTable, hook: HookPoint) -> Verdict {
315        for chain in table.chains_for_hook(hook) {
316            if chain.rule_ids.is_empty() {
317                // No rules -- apply chain default policy
318                continue;
319            }
320            // Chain has rules but no metadata-based evaluation here.
321            // Real evaluation happens in evaluate_filter_chains.
322            // Apply chain policy for non-filter tables.
323            match chain.policy {
324                ChainPolicy::Drop => return Verdict::Drop,
325                ChainPolicy::Accept => continue,
326            }
327        }
328        Verdict::Accept
329    }
330
331    /// Evaluate filter chains with full rule matching
332    fn evaluate_filter_chains(&mut self, hook: HookPoint, metadata: &PacketMetadata) -> Verdict {
333        // Clone rule_ids to avoid borrow conflict
334        let chains: Vec<(ChainPolicy, Vec<u64>)> = self
335            .filter
336            .chains_for_hook(hook)
337            .iter()
338            .map(|c| (c.policy, c.rule_ids.clone()))
339            .collect();
340
341        for (policy, rule_ids) in &chains {
342            for &rule_id in rule_ids {
343                if let Some(rule) = self.rule_engine.get_rule_mut(rule_id) {
344                    if !rule.enabled {
345                        continue;
346                    }
347                    if rule.matches_packet(metadata) {
348                        rule.packets += 1;
349                        rule.bytes += metadata.packet_len as u64;
350                        match rule.action {
351                            RuleAction::Accept => return Verdict::Accept,
352                            RuleAction::Drop => return Verdict::Drop,
353                            RuleAction::Reject => return Verdict::Reject,
354                            RuleAction::Log => continue, // Log and continue evaluation
355                            RuleAction::Return => break, // Return to calling chain
356                            _ => return Verdict::Accept,
357                        }
358                    }
359                }
360            }
361
362            // No rule matched (or Return action) -- apply chain default policy
363            match policy {
364                ChainPolicy::Accept => {} // continue to next chain
365                ChainPolicy::Drop => return Verdict::Drop,
366            }
367        }
368
369        Verdict::Accept
370    }
371}
372
373impl Default for FirewallEngine {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
379// ============================================================================
380// Global State
381// ============================================================================
382
383static FIREWALL_ENGINE: GlobalState<spin::Mutex<FirewallEngine>> = GlobalState::new();
384
385/// Initialize the firewall chain subsystem
386pub fn init() -> Result<(), KernelError> {
387    FIREWALL_ENGINE
388        .init(spin::Mutex::new(FirewallEngine::new()))
389        .map_err(|_| KernelError::InvalidAddress { addr: 0 })?;
390    Ok(())
391}
392
393/// Access the global firewall engine
394pub fn with_engine<R, F: FnOnce(&mut FirewallEngine) -> R>(f: F) -> Option<R> {
395    FIREWALL_ENGINE.with(|lock| {
396        let mut engine = lock.lock();
397        f(&mut engine)
398    })
399}
400
401// ============================================================================
402// Tests
403// ============================================================================
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_point_equality() {
        assert_eq!(HookPoint::Input, HookPoint::Input);
        assert_ne!(HookPoint::Input, HookPoint::Output);
    }

    #[test]
    fn test_chain_policy_default() {
        assert_eq!(ChainPolicy::default(), ChainPolicy::Accept);
    }

    #[test]
    fn test_verdict_default() {
        assert_eq!(Verdict::default(), Verdict::Accept);
    }

    #[test]
    fn test_chain_new() {
        let chain = Chain::new("TEST", HookPoint::Input, ChainPolicy::Drop);
        assert_eq!(chain.name, "TEST");
        assert_eq!(chain.hook_point, HookPoint::Input);
        assert_eq!(chain.policy, ChainPolicy::Drop);
        assert_eq!(chain.rule_count(), 0);
    }

    #[test]
    fn test_chain_add_remove_rule() {
        let mut chain = Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept);
        chain.add_rule(1);
        chain.add_rule(2);
        chain.add_rule(3);
        assert_eq!(chain.rule_count(), 3);

        assert!(chain.remove_rule(2));
        assert_eq!(chain.rule_count(), 2);
        // Removal keeps the relative order of the remaining rules.
        assert_eq!(chain.rule_ids, alloc::vec![1, 3]);
        assert!(!chain.remove_rule(99));
    }

    #[test]
    fn test_chain_insert_rule() {
        let mut chain = Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept);
        chain.add_rule(10);
        chain.add_rule(30);
        chain.insert_rule(1, 20);
        assert_eq!(chain.rule_ids, alloc::vec![10, 20, 30]);
    }

    #[test]
    fn test_chain_insert_clamped() {
        let mut chain = Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept);
        chain.add_rule(1);
        chain.insert_rule(100, 2); // Index beyond length -> appended
        assert_eq!(chain.rule_ids, alloc::vec![1, 2]);
    }

    #[test]
    fn test_filter_table_default_chains() {
        let table = FirewallTable::new_filter();
        assert_eq!(table.table_type, ChainType::Filter);
        assert_eq!(table.chains.len(), 3);
        assert!(table.get_chain("INPUT").is_some());
        assert!(table.get_chain("FORWARD").is_some());
        assert!(table.get_chain("OUTPUT").is_some());
        assert!(table.get_chain("NONEXISTENT").is_none());
    }

    #[test]
    fn test_nat_table_default_chains() {
        let table = FirewallTable::new_nat();
        assert_eq!(table.table_type, ChainType::Nat);
        assert_eq!(table.chains.len(), 3);
        assert!(table.get_chain("PREROUTING").is_some());
        assert!(table.get_chain("POSTROUTING").is_some());
        assert!(table.get_chain("OUTPUT").is_some());
    }

    #[test]
    fn test_mangle_table_default_chains() {
        let table = FirewallTable::new_mangle();
        assert_eq!(table.table_type, ChainType::Mangle);
        assert_eq!(table.chains.len(), 5);
    }

    #[test]
    fn test_table_chains_for_hook() {
        let table = FirewallTable::new_filter();
        let input_chains = table.chains_for_hook(HookPoint::Input);
        assert_eq!(input_chains.len(), 1);
        assert_eq!(input_chains[0].name, "INPUT");

        let prerouting_chains = table.chains_for_hook(HookPoint::PreRouting);
        assert_eq!(prerouting_chains.len(), 0);
    }

    #[test]
    fn test_engine_default() {
        let engine = FirewallEngine::new();
        assert_eq!(engine.total_packets, 0);
        assert_eq!(engine.dropped_packets, 0);
        assert_eq!(engine.filter.chains.len(), 3);
        assert_eq!(engine.nat.chains.len(), 3);
        assert_eq!(engine.mangle.chains.len(), 5);
    }

    #[test]
    fn test_engine_process_accept_default() {
        let mut engine = FirewallEngine::new();
        let metadata = PacketMetadata::default();
        let verdict = engine.process_packet(HookPoint::Input, &metadata);
        assert_eq!(verdict, Verdict::Accept);
        assert_eq!(engine.total_packets, 1);
        assert_eq!(engine.dropped_packets, 0);
    }

    #[test]
    fn test_engine_forward_default_policy_drops() {
        // The filter FORWARD chain defaults to Drop, so forwarded packets are
        // dropped and counted until a rule or policy change allows them.
        let mut engine = FirewallEngine::new();
        let metadata = PacketMetadata::default();
        assert_eq!(engine.process_packet(HookPoint::Forward, &metadata), Verdict::Drop);
        assert_eq!(engine.process_packet(HookPoint::Forward, &metadata), Verdict::Drop);
        assert_eq!(engine.total_packets, 2);
        assert_eq!(engine.dropped_packets, 2);
    }

    #[test]
    fn test_engine_filter_policy_applies_at_output() {
        let mut engine = FirewallEngine::new();
        assert!(engine.set_filter_policy("OUTPUT", ChainPolicy::Drop));
        let metadata = PacketMetadata::default();
        assert_eq!(engine.process_packet(HookPoint::Output, &metadata), Verdict::Drop);
        assert_eq!(engine.dropped_packets, 1);
        // Other hooks are unaffected by the OUTPUT policy.
        assert_eq!(engine.process_packet(HookPoint::Input, &metadata), Verdict::Accept);
        assert_eq!(engine.dropped_packets, 1);
    }

    #[test]
    fn test_engine_filter_not_consulted_at_prerouting() {
        let mut engine = FirewallEngine::new();
        assert!(engine.set_filter_policy("INPUT", ChainPolicy::Drop));
        let metadata = PacketMetadata::default();
        assert_eq!(engine.process_packet(HookPoint::PreRouting, &metadata), Verdict::Accept);
        assert_eq!(engine.process_packet(HookPoint::PostRouting, &metadata), Verdict::Accept);
        assert_eq!(engine.dropped_packets, 0);
    }

    #[test]
    fn test_engine_chain_lookup_failures() {
        let mut engine = FirewallEngine::new();
        assert!(!engine.set_filter_policy("NONEXISTENT", ChainPolicy::Drop));
        assert!(!engine.add_to_filter_chain("NONEXISTENT", 1));
        // The NAT table has no INPUT chain.
        assert!(!engine.add_to_nat_chain("INPUT", 1));
        assert!(engine.add_to_nat_chain("PREROUTING", 1));
        assert_eq!(
            engine.nat.get_chain("PREROUTING").map(Chain::rule_count),
            Some(1)
        );
    }
}