1#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::collections::BTreeMap;
12#[cfg(feature = "alloc")]
13use alloc::string::String;
14#[cfg(feature = "alloc")]
15use alloc::vec::Vec;
16
17use super::conntrack::ConntrackState;
18use crate::net::{Ipv4Address, Port};
19
/// Protocol selector shared by rule criteria and packet metadata.
///
/// When used in rule criteria, [`Protocol::Any`] is treated as a wildcard that
/// accepts every packet protocol.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Protocol {
    /// Wildcard selector; also the `Default` value.
    #[default]
    Any,
    /// TCP.
    Tcp,
    /// UDP.
    Udp,
    /// ICMP (IPv4).
    Icmp,
    /// ICMPv6.
    Icmpv6,
}
39
/// Filter on the TCP flag byte.
///
/// A packet's flags satisfy the filter when the bits selected by `mask` are
/// exactly equal to `value`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct TcpFlags {
    /// Flag bits that are inspected.
    pub mask: u8,
    /// Required state of the inspected bits. A `value` bit outside `mask`
    /// makes the filter unsatisfiable.
    pub value: u8,
}

impl TcpFlags {
    /// FIN flag bit.
    pub const FIN: u8 = 0x01;
    /// SYN flag bit.
    pub const SYN: u8 = 0x02;
    /// RST flag bit.
    pub const RST: u8 = 0x04;
    /// PSH flag bit.
    pub const PSH: u8 = 0x08;
    /// ACK flag bit.
    pub const ACK: u8 = 0x10;
    /// URG flag bit.
    pub const URG: u8 = 0x20;

    /// Builds a filter from a raw mask/value pair.
    pub const fn new(mask: u8, value: u8) -> Self {
        Self { mask, value }
    }

    /// SYN set and ACK clear (the opening segment of a TCP handshake);
    /// all other bits are ignored.
    pub const fn syn_only() -> Self {
        Self::new(Self::SYN | Self::ACK, Self::SYN)
    }

    /// ACK set; all other bits are ignored.
    pub const fn established() -> Self {
        Self::new(Self::ACK, Self::ACK)
    }

    /// Returns `true` when `flags` satisfies this filter.
    pub fn matches(&self, flags: u8) -> bool {
        let selected = flags & self.mask;
        selected == self.value
    }
}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq)]
94pub struct PortRange {
95 pub start: Port,
96 pub end: Port,
97}
98
99impl PortRange {
100 pub const fn single(port: Port) -> Self {
102 Self {
103 start: port,
104 end: port,
105 }
106 }
107
108 pub const fn range(start: Port, end: Port) -> Self {
110 Self { start, end }
111 }
112
113 pub const fn any() -> Self {
115 Self {
116 start: 0,
117 end: 65535,
118 }
119 }
120
121 pub fn contains(&self, port: Port) -> bool {
123 port >= self.start && port <= self.end
124 }
125}
126
127impl Default for PortRange {
128 fn default() -> Self {
129 Self::any()
130 }
131}
132
133#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct CidrAddress {
140 pub address: Ipv4Address,
142 pub prefix_len: u8,
144}
145
146impl CidrAddress {
147 pub const fn new(address: Ipv4Address, prefix_len: u8) -> Self {
149 Self {
150 address,
151 prefix_len,
152 }
153 }
154
155 pub const fn any() -> Self {
157 Self {
158 address: Ipv4Address::ANY,
159 prefix_len: 0,
160 }
161 }
162
163 pub const fn host(address: Ipv4Address) -> Self {
165 Self {
166 address,
167 prefix_len: 32,
168 }
169 }
170
171 fn mask(&self) -> u32 {
173 if self.prefix_len == 0 {
174 0
175 } else if self.prefix_len >= 32 {
176 0xFFFF_FFFF
177 } else {
178 0xFFFF_FFFF << (32 - self.prefix_len)
179 }
180 }
181
182 pub fn matches(&self, addr: &Ipv4Address) -> bool {
184 let mask = self.mask();
185 (self.address.to_u32() & mask) == (addr.to_u32() & mask)
186 }
187}
188
189impl Default for CidrAddress {
190 fn default() -> Self {
191 Self::any()
192 }
193}
194
195#[derive(Debug, Clone, Default)]
201pub struct MatchCriteria {
202 pub src_ip: Option<CidrAddress>,
204 pub dst_ip: Option<CidrAddress>,
206 pub src_port: Option<PortRange>,
208 pub dst_port: Option<PortRange>,
210 pub protocol: Protocol,
212 pub tcp_flags: Option<TcpFlags>,
214 pub conn_state: Option<ConntrackState>,
216 pub negate_src: bool,
218 pub negate_dst: bool,
220}
221
222impl MatchCriteria {
223 pub fn new() -> Self {
225 Self::default()
226 }
227
228 pub fn with_src_ip(mut self, cidr: CidrAddress) -> Self {
230 self.src_ip = Some(cidr);
231 self
232 }
233
234 pub fn with_dst_ip(mut self, cidr: CidrAddress) -> Self {
236 self.dst_ip = Some(cidr);
237 self
238 }
239
240 pub fn with_src_port(mut self, range: PortRange) -> Self {
242 self.src_port = Some(range);
243 self
244 }
245
246 pub fn with_dst_port(mut self, range: PortRange) -> Self {
248 self.dst_port = Some(range);
249 self
250 }
251
252 pub fn with_protocol(mut self, proto: Protocol) -> Self {
254 self.protocol = proto;
255 self
256 }
257
258 pub fn with_tcp_flags(mut self, flags: TcpFlags) -> Self {
260 self.tcp_flags = Some(flags);
261 self
262 }
263
264 pub fn with_conn_state(mut self, state: ConntrackState) -> Self {
266 self.conn_state = Some(state);
267 self
268 }
269}
270
271#[derive(Debug, Clone, PartialEq, Eq, Default)]
277pub enum RuleAction {
278 #[default]
280 Accept,
281 Drop,
283 Reject,
285 Log,
287 Jump(String),
289 Return,
291 Masquerade,
293 Snat(Ipv4Address),
295 Dnat(Ipv4Address, Port),
297}
298
299#[derive(Debug, Clone)]
308pub struct PacketMetadata {
309 pub src_ip: Ipv4Address,
311 pub dst_ip: Ipv4Address,
313 pub src_port: Port,
315 pub dst_port: Port,
317 pub protocol: Protocol,
319 pub tcp_flags: u8,
321 pub conn_state: Option<ConntrackState>,
323 pub packet_len: u16,
325}
326
327impl Default for PacketMetadata {
328 fn default() -> Self {
329 Self {
330 src_ip: Ipv4Address::ANY,
331 dst_ip: Ipv4Address::ANY,
332 src_port: 0,
333 dst_port: 0,
334 protocol: Protocol::default(),
335 tcp_flags: 0,
336 conn_state: None,
337 packet_len: 0,
338 }
339 }
340}
341
342#[derive(Debug, Clone)]
348pub struct FirewallRule {
349 pub id: u64,
351 pub priority: u32,
353 pub criteria: MatchCriteria,
355 pub action: RuleAction,
357 pub packets: u64,
359 pub bytes: u64,
361 pub enabled: bool,
363 pub comment: String,
365}
366
367impl FirewallRule {
368 pub fn new(id: u64, criteria: MatchCriteria, action: RuleAction) -> Self {
370 Self {
371 id,
372 priority: 0,
373 criteria,
374 action,
375 packets: 0,
376 bytes: 0,
377 enabled: true,
378 comment: String::new(),
379 }
380 }
381
382 pub fn with_priority(mut self, priority: u32) -> Self {
384 self.priority = priority;
385 self
386 }
387
388 pub fn with_comment(mut self, comment: &str) -> Self {
390 self.comment = String::from(comment);
391 self
392 }
393
394 pub fn matches_packet(&self, meta: &PacketMetadata) -> bool {
396 if self.criteria.protocol != Protocol::Any && self.criteria.protocol != meta.protocol {
398 return false;
399 }
400
401 if let Some(ref cidr) = self.criteria.src_ip {
403 let matches = cidr.matches(&meta.src_ip);
404 if matches == self.criteria.negate_src {
405 return false;
406 }
407 }
408
409 if let Some(ref cidr) = self.criteria.dst_ip {
411 let matches = cidr.matches(&meta.dst_ip);
412 if matches == self.criteria.negate_dst {
413 return false;
414 }
415 }
416
417 if let Some(ref range) = self.criteria.src_port {
419 if !range.contains(meta.src_port) {
420 return false;
421 }
422 }
423
424 if let Some(ref range) = self.criteria.dst_port {
426 if !range.contains(meta.dst_port) {
427 return false;
428 }
429 }
430
431 if let Some(ref flags) = self.criteria.tcp_flags {
433 if !flags.matches(meta.tcp_flags) {
434 return false;
435 }
436 }
437
438 if let Some(ref expected_state) = self.criteria.conn_state {
440 match meta.conn_state {
441 Some(ref actual_state) if actual_state == expected_state => {}
442 _ => return false,
443 }
444 }
445
446 true
447 }
448
449 pub fn reset_counters(&mut self) {
451 self.packets = 0;
452 self.bytes = 0;
453 }
454}
455
456pub struct RuleEngine {
462 rules: BTreeMap<u64, FirewallRule>,
464 next_id: u64,
466}
467
468impl RuleEngine {
469 pub fn new() -> Self {
471 Self {
472 rules: BTreeMap::new(),
473 next_id: 1,
474 }
475 }
476
477 pub fn add_rule(&mut self, mut rule: FirewallRule) -> u64 {
479 let id = self.next_id;
480 self.next_id += 1;
481 rule.id = id;
482 self.rules.insert(id, rule);
483 id
484 }
485
486 pub fn remove_rule(&mut self, id: u64) -> Option<FirewallRule> {
488 self.rules.remove(&id)
489 }
490
491 pub fn get_rule(&self, id: u64) -> Option<&FirewallRule> {
493 self.rules.get(&id)
494 }
495
496 pub fn get_rule_mut(&mut self, id: u64) -> Option<&mut FirewallRule> {
498 self.rules.get_mut(&id)
499 }
500
501 pub fn rule_count(&self) -> usize {
503 self.rules.len()
504 }
505
506 pub fn evaluate(&mut self, rule_ids: &[u64], metadata: &PacketMetadata) -> Option<RuleAction> {
509 for &id in rule_ids {
510 if let Some(rule) = self.rules.get_mut(&id) {
511 if rule.enabled && rule.matches_packet(metadata) {
512 rule.packets += 1;
513 rule.bytes += metadata.packet_len as u64;
514 return Some(rule.action.clone());
515 }
516 }
517 }
518 None
519 }
520
521 pub fn rules_by_priority(&self) -> Vec<&FirewallRule> {
523 let mut rules: Vec<&FirewallRule> = self.rules.values().collect();
524 rules.sort_by_key(|r| r.priority);
525 rules
526 }
527}
528
529impl Default for RuleEngine {
530 fn default() -> Self {
531 Self::new()
532 }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// A TCP SYN from 192.168.1.100:12345 to 10.0.0.1:80, state New, 64 bytes.
    fn test_metadata() -> PacketMetadata {
        PacketMetadata {
            src_ip: Ipv4Address::new(192, 168, 1, 100),
            dst_ip: Ipv4Address::new(10, 0, 0, 1),
            src_port: 12345,
            dst_port: 80,
            protocol: Protocol::Tcp,
            tcp_flags: TcpFlags::SYN,
            conn_state: Some(ConntrackState::New),
            packet_len: 64,
        }
    }

    #[test]
    fn test_protocol_default() {
        assert_eq!(Protocol::default(), Protocol::Any);
    }

    #[test]
    fn test_port_range_single() {
        let range = PortRange::single(80);
        assert!(range.contains(80));
        assert!(!range.contains(81));
        assert!(!range.contains(79));
    }

    #[test]
    fn test_port_range_range() {
        let range = PortRange::range(1024, 65535);
        assert!(!range.contains(80));
        assert!(range.contains(1024));
        assert!(range.contains(50000));
        assert!(range.contains(65535));
    }

    #[test]
    fn test_port_range_any() {
        let range = PortRange::any();
        assert!(range.contains(0));
        assert!(range.contains(80));
        assert!(range.contains(65535));
    }

    #[test]
    fn test_cidr_matches_slash32() {
        let cidr = CidrAddress::host(Ipv4Address::new(192, 168, 1, 1));
        assert!(cidr.matches(&Ipv4Address::new(192, 168, 1, 1)));
        assert!(!cidr.matches(&Ipv4Address::new(192, 168, 1, 2)));
    }

    #[test]
    fn test_cidr_matches_slash24() {
        let cidr = CidrAddress::new(Ipv4Address::new(192, 168, 1, 0), 24);
        assert!(cidr.matches(&Ipv4Address::new(192, 168, 1, 0)));
        assert!(cidr.matches(&Ipv4Address::new(192, 168, 1, 255)));
        assert!(!cidr.matches(&Ipv4Address::new(192, 168, 2, 1)));
    }

    #[test]
    fn test_cidr_matches_slash0() {
        let cidr = CidrAddress::any();
        assert!(cidr.matches(&Ipv4Address::new(1, 2, 3, 4)));
        assert!(cidr.matches(&Ipv4Address::new(255, 255, 255, 255)));
    }

    #[test]
    fn test_cidr_matches_slash16() {
        let cidr = CidrAddress::new(Ipv4Address::new(10, 0, 0, 0), 16);
        assert!(cidr.matches(&Ipv4Address::new(10, 0, 0, 1)));
        assert!(cidr.matches(&Ipv4Address::new(10, 0, 255, 255)));
        assert!(!cidr.matches(&Ipv4Address::new(10, 1, 0, 1)));
    }

    #[test]
    fn test_cidr_prefix_longer_than_32_is_host_match() {
        let cidr = CidrAddress::new(Ipv4Address::new(10, 0, 0, 1), 40);
        assert!(cidr.matches(&Ipv4Address::new(10, 0, 0, 1)));
        assert!(!cidr.matches(&Ipv4Address::new(10, 0, 0, 2)));
    }

    #[test]
    fn test_tcp_flags_syn_only() {
        let flags = TcpFlags::syn_only();
        assert!(flags.matches(TcpFlags::SYN));
        assert!(!flags.matches(TcpFlags::SYN | TcpFlags::ACK));
        assert!(!flags.matches(TcpFlags::ACK));
    }

    #[test]
    fn test_tcp_flags_established() {
        let flags = TcpFlags::established();
        assert!(flags.matches(TcpFlags::ACK));
        assert!(flags.matches(TcpFlags::SYN | TcpFlags::ACK));
        assert!(!flags.matches(TcpFlags::SYN));
    }

    #[test]
    fn test_rule_matches_all() {
        let rule = FirewallRule::new(1, MatchCriteria::new(), RuleAction::Accept);
        let meta = test_metadata();
        assert!(rule.matches_packet(&meta));
    }

    #[test]
    fn test_rule_matches_src_ip() {
        let criteria = MatchCriteria::new()
            .with_src_ip(CidrAddress::new(Ipv4Address::new(192, 168, 1, 0), 24));
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        let meta = test_metadata();
        assert!(rule.matches_packet(&meta));

        let mut meta2 = test_metadata();
        meta2.src_ip = Ipv4Address::new(10, 0, 0, 5);
        assert!(!rule.matches_packet(&meta2));
    }

    #[test]
    fn test_rule_negated_src_ip() {
        let mut criteria = MatchCriteria::new()
            .with_src_ip(CidrAddress::new(Ipv4Address::new(192, 168, 1, 0), 24));
        criteria.negate_src = true;
        let rule = FirewallRule::new(1, criteria, RuleAction::Drop);

        // 192.168.1.100 lies inside the /24, so the negated criterion rejects it.
        assert!(!rule.matches_packet(&test_metadata()));

        let mut outside = test_metadata();
        outside.src_ip = Ipv4Address::new(10, 0, 0, 5);
        assert!(rule.matches_packet(&outside));
    }

    #[test]
    fn test_rule_negated_dst_ip() {
        let mut criteria = MatchCriteria::new()
            .with_dst_ip(CidrAddress::new(Ipv4Address::new(10, 0, 0, 0), 8));
        criteria.negate_dst = true;
        let rule = FirewallRule::new(1, criteria, RuleAction::Drop);

        // 10.0.0.1 lies inside 10.0.0.0/8, so the negated criterion rejects it.
        assert!(!rule.matches_packet(&test_metadata()));

        let mut outside = test_metadata();
        outside.dst_ip = Ipv4Address::new(172, 16, 0, 1);
        assert!(rule.matches_packet(&outside));
    }

    #[test]
    fn test_rule_matches_dst_port() {
        let criteria = MatchCriteria::new()
            .with_protocol(Protocol::Tcp)
            .with_dst_port(PortRange::single(80));
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        let meta = test_metadata();
        assert!(rule.matches_packet(&meta));

        let mut meta2 = test_metadata();
        meta2.dst_port = 443;
        assert!(!rule.matches_packet(&meta2));
    }

    #[test]
    fn test_rule_matches_src_port_range() {
        let criteria = MatchCriteria::new().with_src_port(PortRange::range(1024, 65535));
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        assert!(rule.matches_packet(&test_metadata()));

        let mut meta = test_metadata();
        meta.src_port = 80;
        assert!(!rule.matches_packet(&meta));
    }

    #[test]
    fn test_rule_matches_tcp_flags_on_tcp() {
        let criteria = MatchCriteria::new()
            .with_protocol(Protocol::Tcp)
            .with_tcp_flags(TcpFlags::syn_only());
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        assert!(rule.matches_packet(&test_metadata()));

        let mut meta = test_metadata();
        meta.tcp_flags = TcpFlags::SYN | TcpFlags::ACK;
        assert!(!rule.matches_packet(&meta));
    }

    #[test]
    fn test_rule_matches_protocol() {
        let criteria = MatchCriteria::new().with_protocol(Protocol::Udp);
        let rule = FirewallRule::new(1, criteria, RuleAction::Drop);
        let meta = test_metadata();
        assert!(!rule.matches_packet(&meta));
    }

    #[test]
    fn test_rule_matches_conn_state() {
        let criteria = MatchCriteria::new().with_conn_state(ConntrackState::Established);
        let rule = FirewallRule::new(1, criteria, RuleAction::Accept);
        let meta = test_metadata();
        assert!(!rule.matches_packet(&meta));

        let mut meta2 = test_metadata();
        meta2.conn_state = Some(ConntrackState::Established);
        assert!(rule.matches_packet(&meta2));

        let mut meta3 = test_metadata();
        meta3.conn_state = None;
        assert!(!rule.matches_packet(&meta3));
    }

    #[test]
    fn test_rule_engine_add_evaluate() {
        let mut engine = RuleEngine::new();
        let criteria = MatchCriteria::new()
            .with_protocol(Protocol::Tcp)
            .with_dst_port(PortRange::single(80));
        let rule = FirewallRule::new(0, criteria, RuleAction::Accept);
        let id = engine.add_rule(rule);
        assert_eq!(id, 1);
        assert_eq!(engine.rule_count(), 1);

        let meta = test_metadata();
        let action = engine.evaluate(&[id], &meta);
        assert_eq!(action, Some(RuleAction::Accept));

        let r = engine.get_rule(id).unwrap();
        assert_eq!(r.packets, 1);
        assert_eq!(r.bytes, 64);
    }

    #[test]
    fn test_rule_engine_no_match() {
        let mut engine = RuleEngine::new();
        let criteria = MatchCriteria::new().with_protocol(Protocol::Udp);
        let rule = FirewallRule::new(0, criteria, RuleAction::Drop);
        let id = engine.add_rule(rule);

        let meta = test_metadata();
        let action = engine.evaluate(&[id], &meta);
        assert_eq!(action, None);
    }

    #[test]
    fn test_rule_engine_first_match_wins() {
        let mut engine = RuleEngine::new();
        let drop_id = engine.add_rule(FirewallRule::new(0, MatchCriteria::new(), RuleAction::Drop));
        let accept_id =
            engine.add_rule(FirewallRule::new(0, MatchCriteria::new(), RuleAction::Accept));
        assert_ne!(drop_id, accept_id);

        let meta = test_metadata();
        // Evaluation follows the order of `rule_ids`, not insertion order.
        assert_eq!(
            engine.evaluate(&[accept_id, drop_id], &meta),
            Some(RuleAction::Accept)
        );
        assert_eq!(engine.get_rule(accept_id).unwrap().packets, 1);
        assert_eq!(engine.get_rule(drop_id).unwrap().packets, 0);
    }

    #[test]
    fn test_rule_engine_unknown_id_is_skipped() {
        let mut engine = RuleEngine::new();
        let id = engine.add_rule(FirewallRule::new(0, MatchCriteria::new(), RuleAction::Drop));
        let meta = test_metadata();
        assert_eq!(engine.evaluate(&[id + 100], &meta), None);
        assert_eq!(engine.evaluate(&[id + 100, id], &meta), Some(RuleAction::Drop));
    }

    #[test]
    fn test_rule_engine_counters_accumulate_and_reset() {
        let mut engine = RuleEngine::new();
        let id = engine.add_rule(FirewallRule::new(0, MatchCriteria::new(), RuleAction::Accept));
        let meta = test_metadata();
        engine.evaluate(&[id], &meta);
        engine.evaluate(&[id], &meta);

        let rule = engine.get_rule(id).unwrap();
        assert_eq!(rule.packets, 2);
        assert_eq!(rule.bytes, 128);

        engine.get_rule_mut(id).unwrap().reset_counters();
        let rule = engine.get_rule(id).unwrap();
        assert_eq!(rule.packets, 0);
        assert_eq!(rule.bytes, 0);
    }

    #[test]
    fn test_rule_engine_remove() {
        let mut engine = RuleEngine::new();
        let rule = FirewallRule::new(0, MatchCriteria::new(), RuleAction::Drop);
        let id = engine.add_rule(rule);
        assert_eq!(engine.rule_count(), 1);

        engine.remove_rule(id);
        assert_eq!(engine.rule_count(), 0);
    }

    #[test]
    fn test_rule_ids_are_assigned_sequentially() {
        let mut engine = RuleEngine::new();
        let first = engine.add_rule(FirewallRule::new(99, MatchCriteria::new(), RuleAction::Accept));
        let second = engine.add_rule(FirewallRule::new(99, MatchCriteria::new(), RuleAction::Accept));
        assert_eq!((first, second), (1, 2));
        assert_eq!(engine.get_rule(second).unwrap().id, 2);

        assert!(engine.remove_rule(first).is_some());
        assert!(engine.get_rule(first).is_none());
        assert!(engine.remove_rule(first).is_none());
    }

    #[test]
    fn test_rules_by_priority() {
        let mut engine = RuleEngine::new();
        let low = engine.add_rule(
            FirewallRule::new(0, MatchCriteria::new(), RuleAction::Accept).with_priority(10),
        );
        let high = engine.add_rule(
            FirewallRule::new(0, MatchCriteria::new(), RuleAction::Drop).with_priority(1),
        );
        let tie = engine.add_rule(
            FirewallRule::new(0, MatchCriteria::new(), RuleAction::Log).with_priority(10),
        );

        let ordered = engine.rules_by_priority();
        assert_eq!(ordered.len(), 3);
        assert_eq!(ordered[0].id, high);
        // Equal priorities keep ascending-id order (stable sort over the BTreeMap).
        assert_eq!(ordered[1].id, low);
        assert_eq!(ordered[2].id, tie);
    }

    #[test]
    fn test_rule_disabled() {
        let mut engine = RuleEngine::new();
        let mut rule = FirewallRule::new(0, MatchCriteria::new(), RuleAction::Drop);
        rule.enabled = false;
        let id = engine.add_rule(rule);

        let meta = test_metadata();
        let action = engine.evaluate(&[id], &meta);
        assert_eq!(action, None);
    }
}