1#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::string::String;
12#[cfg(feature = "alloc")]
13use alloc::vec::Vec;
14
15use super::rules::{FirewallRule, PacketMetadata, RuleAction, RuleEngine};
16use crate::{error::KernelError, sync::once_lock::GlobalState};
17
/// Point in the packet path at which chains are evaluated.
///
/// Each table attaches chains to a subset of these hooks; see the
/// `FirewallTable::new_*` constructors for which table uses which hook.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum HookPoint {
    /// Hook of the `PREROUTING` chains (nat and mangle tables).
    PreRouting,
    /// Hook of the `INPUT` chains (filter and mangle tables).
    Input,
    /// Hook of the `FORWARD` chains (filter and mangle tables).
    Forward,
    /// Hook of the `OUTPUT` chains (filter, nat and mangle tables).
    Output,
    /// Hook of the `POSTROUTING` chains (nat and mangle tables).
    PostRouting,
}
36
/// Fallback decision of a chain, used by filter evaluation once the chain's
/// rules have been exhausted without reaching a verdict.
///
/// Defaults to [`ChainPolicy::Accept`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ChainPolicy {
    /// Let the packet continue (the default).
    #[default]
    Accept,
    /// Stop the packet with a `Drop` verdict.
    Drop,
}
46
/// Kind of a [`FirewallTable`]: which built-in table it models.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChainType {
    /// Packet filtering table (INPUT / FORWARD / OUTPUT chains).
    Filter,
    /// Address translation table (PREROUTING / POSTROUTING / OUTPUT chains).
    Nat,
    /// Packet alteration table (chains on all five hooks).
    Mangle,
}
57
/// Outcome of running a packet through the firewall.
///
/// `Drop` and `Reject` are counted as dropped by
/// `FirewallEngine::process_packet`. Defaults to [`Verdict::Accept`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Verdict {
    /// Packet may proceed (the default).
    #[default]
    Accept,
    /// Packet is discarded.
    Drop,
    /// Packet is refused.
    Reject,
    /// Packet is queued; not produced by the evaluation code in this module.
    Queue,
}
71
72#[derive(Debug, Clone)]
78pub struct Chain {
79 pub name: String,
81 pub hook_point: HookPoint,
83 pub policy: ChainPolicy,
85 pub rule_ids: Vec<u64>,
87}
88
89impl Chain {
90 pub fn new(name: &str, hook_point: HookPoint, policy: ChainPolicy) -> Self {
92 Self {
93 name: String::from(name),
94 hook_point,
95 policy,
96 rule_ids: Vec::new(),
97 }
98 }
99
100 pub fn add_rule(&mut self, rule_id: u64) {
102 self.rule_ids.push(rule_id);
103 }
104
105 pub fn remove_rule(&mut self, rule_id: u64) -> bool {
107 if let Some(pos) = self.rule_ids.iter().position(|&id| id == rule_id) {
108 self.rule_ids.remove(pos);
109 true
110 } else {
111 false
112 }
113 }
114
115 pub fn insert_rule(&mut self, index: usize, rule_id: u64) {
117 let idx = index.min(self.rule_ids.len());
118 self.rule_ids.insert(idx, rule_id);
119 }
120
121 pub fn rule_count(&self) -> usize {
123 self.rule_ids.len()
124 }
125}
126
127#[derive(Debug, Clone)]
133pub struct FirewallTable {
134 pub table_type: ChainType,
136 pub chains: Vec<Chain>,
138}
139
140impl FirewallTable {
141 pub fn new_filter() -> Self {
143 Self {
144 table_type: ChainType::Filter,
145 chains: alloc::vec![
146 Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept),
147 Chain::new("FORWARD", HookPoint::Forward, ChainPolicy::Drop),
148 Chain::new("OUTPUT", HookPoint::Output, ChainPolicy::Accept),
149 ],
150 }
151 }
152
153 pub fn new_nat() -> Self {
155 Self {
156 table_type: ChainType::Nat,
157 chains: alloc::vec![
158 Chain::new("PREROUTING", HookPoint::PreRouting, ChainPolicy::Accept),
159 Chain::new("POSTROUTING", HookPoint::PostRouting, ChainPolicy::Accept),
160 Chain::new("OUTPUT", HookPoint::Output, ChainPolicy::Accept),
161 ],
162 }
163 }
164
165 pub fn new_mangle() -> Self {
167 Self {
168 table_type: ChainType::Mangle,
169 chains: alloc::vec![
170 Chain::new("PREROUTING", HookPoint::PreRouting, ChainPolicy::Accept),
171 Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept),
172 Chain::new("FORWARD", HookPoint::Forward, ChainPolicy::Accept),
173 Chain::new("OUTPUT", HookPoint::Output, ChainPolicy::Accept),
174 Chain::new("POSTROUTING", HookPoint::PostRouting, ChainPolicy::Accept),
175 ],
176 }
177 }
178
179 pub fn get_chain(&self, name: &str) -> Option<&Chain> {
181 self.chains.iter().find(|c| c.name == name)
182 }
183
184 pub fn get_chain_mut(&mut self, name: &str) -> Option<&mut Chain> {
186 self.chains.iter_mut().find(|c| c.name == name)
187 }
188
189 pub fn chains_for_hook(&self, hook: HookPoint) -> Vec<&Chain> {
191 self.chains
192 .iter()
193 .filter(|c| c.hook_point == hook)
194 .collect()
195 }
196}
197
198pub struct FirewallEngine {
204 pub filter: FirewallTable,
206 pub nat: FirewallTable,
208 pub mangle: FirewallTable,
210 pub rule_engine: RuleEngine,
212 pub total_packets: u64,
214 pub dropped_packets: u64,
216}
217
218impl FirewallEngine {
219 pub fn new() -> Self {
221 Self {
222 filter: FirewallTable::new_filter(),
223 nat: FirewallTable::new_nat(),
224 mangle: FirewallTable::new_mangle(),
225 rule_engine: RuleEngine::new(),
226 total_packets: 0,
227 dropped_packets: 0,
228 }
229 }
230
231 pub fn add_rule(&mut self, rule: FirewallRule) -> u64 {
233 self.rule_engine.add_rule(rule)
234 }
235
236 pub fn add_to_filter_chain(&mut self, chain_name: &str, rule_id: u64) -> bool {
238 if let Some(chain) = self.filter.get_chain_mut(chain_name) {
239 chain.add_rule(rule_id);
240 true
241 } else {
242 false
243 }
244 }
245
246 pub fn add_to_nat_chain(&mut self, chain_name: &str, rule_id: u64) -> bool {
248 if let Some(chain) = self.nat.get_chain_mut(chain_name) {
249 chain.add_rule(rule_id);
250 true
251 } else {
252 false
253 }
254 }
255
256 pub fn set_filter_policy(&mut self, chain_name: &str, policy: ChainPolicy) -> bool {
258 if let Some(chain) = self.filter.get_chain_mut(chain_name) {
259 chain.policy = policy;
260 true
261 } else {
262 false
263 }
264 }
265
266 pub fn process_packet(&mut self, hook: HookPoint, metadata: &PacketMetadata) -> Verdict {
271 self.total_packets += 1;
272
273 let mangle_verdict = self.evaluate_table_chains(&self.mangle, hook);
275 if mangle_verdict != Verdict::Accept {
276 if mangle_verdict == Verdict::Drop || mangle_verdict == Verdict::Reject {
277 self.dropped_packets += 1;
278 }
279 return mangle_verdict;
280 }
281
282 match hook {
284 HookPoint::PreRouting | HookPoint::PostRouting | HookPoint::Output => {
285 let nat_verdict = self.evaluate_table_chains(&self.nat, hook);
286 if nat_verdict != Verdict::Accept {
287 if nat_verdict == Verdict::Drop || nat_verdict == Verdict::Reject {
288 self.dropped_packets += 1;
289 }
290 return nat_verdict;
291 }
292 }
293 _ => {}
294 }
295
296 match hook {
298 HookPoint::Input | HookPoint::Forward | HookPoint::Output => {
299 let filter_verdict = self.evaluate_filter_chains(hook, metadata);
300 if filter_verdict != Verdict::Accept {
301 if filter_verdict == Verdict::Drop || filter_verdict == Verdict::Reject {
302 self.dropped_packets += 1;
303 }
304 return filter_verdict;
305 }
306 }
307 _ => {}
308 }
309
310 Verdict::Accept
311 }
312
313 fn evaluate_table_chains(&self, table: &FirewallTable, hook: HookPoint) -> Verdict {
315 for chain in table.chains_for_hook(hook) {
316 if chain.rule_ids.is_empty() {
317 continue;
319 }
320 match chain.policy {
324 ChainPolicy::Drop => return Verdict::Drop,
325 ChainPolicy::Accept => continue,
326 }
327 }
328 Verdict::Accept
329 }
330
331 fn evaluate_filter_chains(&mut self, hook: HookPoint, metadata: &PacketMetadata) -> Verdict {
333 let chains: Vec<(ChainPolicy, Vec<u64>)> = self
335 .filter
336 .chains_for_hook(hook)
337 .iter()
338 .map(|c| (c.policy, c.rule_ids.clone()))
339 .collect();
340
341 for (policy, rule_ids) in &chains {
342 for &rule_id in rule_ids {
343 if let Some(rule) = self.rule_engine.get_rule_mut(rule_id) {
344 if !rule.enabled {
345 continue;
346 }
347 if rule.matches_packet(metadata) {
348 rule.packets += 1;
349 rule.bytes += metadata.packet_len as u64;
350 match rule.action {
351 RuleAction::Accept => return Verdict::Accept,
352 RuleAction::Drop => return Verdict::Drop,
353 RuleAction::Reject => return Verdict::Reject,
354 RuleAction::Log => continue, RuleAction::Return => break, _ => return Verdict::Accept,
357 }
358 }
359 }
360 }
361
362 match policy {
364 ChainPolicy::Accept => {} ChainPolicy::Drop => return Verdict::Drop,
366 }
367 }
368
369 Verdict::Accept
370 }
371}
372
373impl Default for FirewallEngine {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
/// Global firewall engine, installed by `init` and accessed through
/// `with_engine`; the `spin::Mutex` serializes access to the engine.
static FIREWALL_ENGINE: GlobalState<spin::Mutex<FirewallEngine>> = GlobalState::new();
384
385pub fn init() -> Result<(), KernelError> {
387 FIREWALL_ENGINE
388 .init(spin::Mutex::new(FirewallEngine::new()))
389 .map_err(|_| KernelError::InvalidAddress { addr: 0 })?;
390 Ok(())
391}
392
393pub fn with_engine<R, F: FnOnce(&mut FirewallEngine) -> R>(f: F) -> Option<R> {
395 FIREWALL_ENGINE.with(|lock| {
396 let mut engine = lock.lock();
397 f(&mut engine)
398 })
399}
400
#[cfg(test)]
mod tests {
    // Unit tests for chains, tables and the engine's packet path. The engine
    // tests only use default tables (no rules), so they exercise chain
    // policies and counters without needing to build `FirewallRule`s.
    use super::*;

    #[test]
    fn test_hook_point_equality() {
        assert_eq!(HookPoint::Input, HookPoint::Input);
        assert_ne!(HookPoint::Input, HookPoint::Output);
    }

    #[test]
    fn test_chain_policy_default() {
        assert_eq!(ChainPolicy::default(), ChainPolicy::Accept);
    }

    #[test]
    fn test_verdict_default() {
        assert_eq!(Verdict::default(), Verdict::Accept);
    }

    #[test]
    fn test_chain_new() {
        let chain = Chain::new("TEST", HookPoint::Input, ChainPolicy::Drop);
        assert_eq!(chain.name, "TEST");
        assert_eq!(chain.hook_point, HookPoint::Input);
        assert_eq!(chain.policy, ChainPolicy::Drop);
        assert_eq!(chain.rule_count(), 0);
    }

    #[test]
    fn test_chain_add_remove_rule() {
        let mut chain = Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept);
        chain.add_rule(1);
        chain.add_rule(2);
        chain.add_rule(3);
        assert_eq!(chain.rule_count(), 3);

        assert!(chain.remove_rule(2));
        assert_eq!(chain.rule_count(), 2);
        assert!(!chain.remove_rule(99));
    }

    #[test]
    fn test_chain_insert_rule() {
        let mut chain = Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept);
        chain.add_rule(10);
        chain.add_rule(30);
        chain.insert_rule(1, 20);
        assert_eq!(chain.rule_ids, alloc::vec![10, 20, 30]);
    }

    #[test]
    fn test_chain_insert_clamped() {
        let mut chain = Chain::new("INPUT", HookPoint::Input, ChainPolicy::Accept);
        chain.add_rule(1);
        // An out-of-range index appends rather than panicking.
        chain.insert_rule(100, 2);
        assert_eq!(chain.rule_ids, alloc::vec![1, 2]);
    }

    #[test]
    fn test_filter_table_default_chains() {
        let table = FirewallTable::new_filter();
        assert_eq!(table.table_type, ChainType::Filter);
        assert_eq!(table.chains.len(), 3);
        assert!(table.get_chain("INPUT").is_some());
        assert!(table.get_chain("FORWARD").is_some());
        assert!(table.get_chain("OUTPUT").is_some());
        assert!(table.get_chain("NONEXISTENT").is_none());
    }

    #[test]
    fn test_nat_table_default_chains() {
        let table = FirewallTable::new_nat();
        assert_eq!(table.table_type, ChainType::Nat);
        assert_eq!(table.chains.len(), 3);
        assert!(table.get_chain("PREROUTING").is_some());
        assert!(table.get_chain("POSTROUTING").is_some());
        assert!(table.get_chain("OUTPUT").is_some());
    }

    #[test]
    fn test_mangle_table_default_chains() {
        let table = FirewallTable::new_mangle();
        assert_eq!(table.table_type, ChainType::Mangle);
        assert_eq!(table.chains.len(), 5);
    }

    #[test]
    fn test_table_chains_for_hook() {
        let table = FirewallTable::new_filter();
        let input_chains = table.chains_for_hook(HookPoint::Input);
        assert_eq!(input_chains.len(), 1);
        assert_eq!(input_chains[0].name, "INPUT");

        let prerouting_chains = table.chains_for_hook(HookPoint::PreRouting);
        assert_eq!(prerouting_chains.len(), 0);
    }

    #[test]
    fn test_engine_default() {
        let engine = FirewallEngine::new();
        assert_eq!(engine.total_packets, 0);
        assert_eq!(engine.dropped_packets, 0);
        assert_eq!(engine.filter.chains.len(), 3);
        assert_eq!(engine.nat.chains.len(), 3);
        assert_eq!(engine.mangle.chains.len(), 5);
    }

    #[test]
    fn test_engine_process_accept_default() {
        let mut engine = FirewallEngine::new();
        let metadata = PacketMetadata::default();
        let verdict = engine.process_packet(HookPoint::Input, &metadata);
        assert_eq!(verdict, Verdict::Accept);
        assert_eq!(engine.total_packets, 1);
        assert_eq!(engine.dropped_packets, 0);
    }

    #[test]
    fn test_engine_forward_default_policy_drops() {
        // FORWARD is built with a Drop policy, so an unmatched packet is dropped
        // and counted.
        let mut engine = FirewallEngine::new();
        let metadata = PacketMetadata::default();
        let verdict = engine.process_packet(HookPoint::Forward, &metadata);
        assert_eq!(verdict, Verdict::Drop);
        assert_eq!(engine.total_packets, 1);
        assert_eq!(engine.dropped_packets, 1);
    }

    #[test]
    fn test_engine_set_filter_policy() {
        let mut engine = FirewallEngine::new();
        assert!(engine.set_filter_policy("INPUT", ChainPolicy::Drop));
        // PREROUTING exists only in the nat and mangle tables.
        assert!(!engine.set_filter_policy("PREROUTING", ChainPolicy::Drop));

        let metadata = PacketMetadata::default();
        assert_eq!(engine.process_packet(HookPoint::Input, &metadata), Verdict::Drop);
        assert_eq!(engine.dropped_packets, 1);
    }

    #[test]
    fn test_engine_attach_to_unknown_chain() {
        let mut engine = FirewallEngine::new();
        assert!(!engine.add_to_filter_chain("PREROUTING", 1));
        assert!(!engine.add_to_nat_chain("INPUT", 1));
        assert!(engine.add_to_filter_chain("INPUT", 1));
        assert_eq!(
            engine.filter.get_chain("INPUT").map(Chain::rule_count),
            Some(1)
        );
    }

    #[test]
    fn test_engine_non_filter_hooks_accept_by_default() {
        let mut engine = FirewallEngine::new();
        let metadata = PacketMetadata::default();
        assert_eq!(engine.process_packet(HookPoint::Output, &metadata), Verdict::Accept);
        assert_eq!(
            engine.process_packet(HookPoint::PreRouting, &metadata),
            Verdict::Accept
        );
        assert_eq!(
            engine.process_packet(HookPoint::PostRouting, &metadata),
            Verdict::Accept
        );
        assert_eq!(engine.total_packets, 3);
        assert_eq!(engine.dropped_packets, 0);
    }
}