1#![allow(dead_code)] use alloc::{collections::BTreeMap, vec::Vec};
10
11use spin::Mutex;
12
13use crate::sync::once_lock::OnceLock;
14
/// IGMP Membership Query message type.
pub const IGMP_MEMBERSHIP_QUERY: u8 = 0x11;
/// IGMPv2 Membership Report message type.
pub const IGMP_MEMBERSHIP_REPORT: u8 = 0x16;
/// IGMPv2 Leave Group message type.
pub const IGMP_LEAVE_GROUP: u8 = 0x17;

/// ICMPv6 type of an MLD Multicast Listener Query (decimal 130).
pub const MLD_QUERY: u8 = 0x82;
/// ICMPv6 type of an MLDv2 Multicast Listener Report (decimal 143).
pub const MLD_REPORT_V2: u8 = 0x8F;

/// MLDv2 record type MODE_IS_INCLUDE (current-state record).
pub const MLD_RECORD_IS_IN: u8 = 1;
/// MLDv2 record type MODE_IS_EXCLUDE (current-state record).
pub const MLD_RECORD_IS_EX: u8 = 2;
/// MLDv2 record type CHANGE_TO_INCLUDE_MODE (filter-mode-change record).
pub const MLD_RECORD_TO_IN: u8 = 3;
/// MLDv2 record type CHANGE_TO_EXCLUDE_MODE (filter-mode-change record).
pub const MLD_RECORD_TO_EX: u8 = 4;
/// MLDv2 record type ALLOW_NEW_SOURCES (source-list-change record).
pub const MLD_RECORD_ALLOW: u8 = 5;
/// MLDv2 record type BLOCK_OLD_SOURCES (source-list-change record).
pub const MLD_RECORD_BLOCK: u8 = 6;

/// Number of `MulticastManager::tick` calls between periodic reports for a
/// joined group; also the timer value a group starts with when first joined.
const UNSOLICITED_REPORT_INTERVAL: u64 = 1_000;
49
/// Errors produced by the multicast membership and message-parsing code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MulticastError {
    /// The address is not an IPv4 (224.0.0.0/4) or IPv6 (ff00::/8) multicast address.
    InvalidAddress,
    /// A leave was requested for a group that is not currently joined.
    GroupNotFound,
    /// NOTE(review): not returned by any code visible in this module.
    AlreadyMember,
    /// NOTE(review): not returned by any code visible in this module.
    GroupLimitReached,
    /// The input buffer is shorter than the fixed message size.
    MalformedMessage,
    /// The global manager is not installed; also what `init` reports when
    /// installing it fails.
    NotInitialized,
}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
77pub struct MulticastGroup {
78 pub address: [u8; 4],
80 pub interface_index: u32,
82}
83
84impl MulticastGroup {
85 pub fn new(address: [u8; 4], interface_index: u32) -> Result<Self, MulticastError> {
87 if !is_ipv4_multicast(&address) {
88 return Err(MulticastError::InvalidAddress);
89 }
90 Ok(Self {
91 address,
92 interface_index,
93 })
94 }
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
99pub struct MulticastGroupV6 {
100 pub address: [u8; 16],
102 pub interface_index: u32,
104}
105
106impl MulticastGroupV6 {
107 pub fn new(address: [u8; 16], interface_index: u32) -> Result<Self, MulticastError> {
109 if !is_ipv6_multicast(&address) {
110 return Err(MulticastError::InvalidAddress);
111 }
112 Ok(Self {
113 address,
114 interface_index,
115 })
116 }
117}
118
/// Returns `true` if `addr` lies in 224.0.0.0/4, i.e. its first octet is 224..=239.
pub fn is_ipv4_multicast(addr: &[u8; 4]) -> bool {
    let first_octet = addr[0];
    (224..=239).contains(&first_octet)
}
123
/// Returns `true` if `addr` is in ff00::/8, i.e. its first byte is `0xFF`.
pub fn is_ipv6_multicast(addr: &[u8; 16]) -> bool {
    addr.first() == Some(&0xFF)
}
128
129#[derive(Debug, Clone, Copy, PartialEq, Eq)]
135pub struct IgmpMessage {
136 pub msg_type: u8,
138 pub max_resp_time: u8,
140 pub checksum: u16,
142 pub group_address: [u8; 4],
144}
145
146impl IgmpMessage {
147 pub const WIRE_SIZE: usize = 8;
149
150 pub fn new(msg_type: u8, max_resp_time: u8, group_address: [u8; 4]) -> Self {
152 let mut msg = Self {
153 msg_type,
154 max_resp_time,
155 checksum: 0,
156 group_address,
157 };
158 msg.checksum = msg.compute_checksum();
159 msg
160 }
161
162 pub fn to_bytes(&self) -> [u8; Self::WIRE_SIZE] {
164 let mut buf = [0u8; Self::WIRE_SIZE];
165 buf[0] = self.msg_type;
166 buf[1] = self.max_resp_time;
167 buf[2] = (self.checksum >> 8) as u8;
168 buf[3] = self.checksum as u8;
169 buf[4..8].copy_from_slice(&self.group_address);
170 buf
171 }
172
173 pub fn from_bytes(data: &[u8]) -> Result<Self, MulticastError> {
175 if data.len() < Self::WIRE_SIZE {
176 return Err(MulticastError::MalformedMessage);
177 }
178 Ok(Self {
179 msg_type: data[0],
180 max_resp_time: data[1],
181 checksum: u16::from_be_bytes([data[2], data[3]]),
182 group_address: [data[4], data[5], data[6], data[7]],
183 })
184 }
185
186 pub fn compute_checksum(&self) -> u16 {
188 let mut bytes = self.to_bytes();
189 bytes[2] = 0;
191 bytes[3] = 0;
192 internet_checksum(&bytes)
193 }
194
195 pub fn verify_checksum(&self) -> bool {
197 let bytes = self.to_bytes();
198 internet_checksum(&bytes) == 0
199 }
200}
201
202#[derive(Debug, Clone, Copy, PartialEq, Eq)]
208pub struct MldMessage {
209 pub msg_type: u8,
211 pub code: u8,
213 pub checksum: u16,
215 pub max_resp_delay: u16,
217 pub reserved: u16,
219 pub multicast_address: [u8; 16],
221}
222
223impl MldMessage {
224 pub const HEADER_SIZE: usize = 24;
226
227 pub fn new_query(max_resp_delay: u16, multicast_address: [u8; 16]) -> Self {
229 Self {
230 msg_type: MLD_QUERY,
231 code: 0,
232 checksum: 0,
233 max_resp_delay,
234 reserved: 0,
235 multicast_address,
236 }
237 }
238
239 pub fn new_report() -> Self {
241 Self {
242 msg_type: MLD_REPORT_V2,
243 code: 0,
244 checksum: 0,
245 max_resp_delay: 0,
246 reserved: 0,
247 multicast_address: [0; 16],
248 }
249 }
250
251 pub fn to_bytes(&self) -> [u8; Self::HEADER_SIZE] {
253 let mut buf = [0u8; Self::HEADER_SIZE];
254 buf[0] = self.msg_type;
255 buf[1] = self.code;
256 buf[2] = (self.checksum >> 8) as u8;
257 buf[3] = self.checksum as u8;
258 buf[4] = (self.max_resp_delay >> 8) as u8;
259 buf[5] = self.max_resp_delay as u8;
260 buf[6] = (self.reserved >> 8) as u8;
261 buf[7] = self.reserved as u8;
262 buf[8..24].copy_from_slice(&self.multicast_address);
263 buf
264 }
265
266 pub fn from_bytes(data: &[u8]) -> Result<Self, MulticastError> {
268 if data.len() < Self::HEADER_SIZE {
269 return Err(MulticastError::MalformedMessage);
270 }
271 let mut addr = [0u8; 16];
272 addr.copy_from_slice(&data[8..24]);
273 Ok(Self {
274 msg_type: data[0],
275 code: data[1],
276 checksum: u16::from_be_bytes([data[2], data[3]]),
277 max_resp_delay: u16::from_be_bytes([data[4], data[5]]),
278 reserved: u16::from_be_bytes([data[6], data[7]]),
279 multicast_address: addr,
280 })
281 }
282}
283
/// One MLDv2 Multicast Address Record as laid out in RFC 3810 §5.2:
/// record type, aux data length, number of sources, multicast address,
/// source addresses, then auxiliary data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MldAddressRecord {
    /// Record type, one of the `MLD_RECORD_*` constants.
    pub record_type: u8,
    /// Length of the auxiliary data in 32-bit words. RFC 3810 requires senders
    /// to leave this at 0. If it is non-zero, serialization emits that many
    /// zero words so the record's framing matches its header.
    pub aux_data_len: u8,
    /// The multicast address the record describes.
    pub multicast_address: [u8; 16],
    /// Source addresses. The on-wire count is 16 bits, so at most `u16::MAX`
    /// of them are serialized; any beyond that are dropped.
    pub source_addresses: Vec<[u8; 16]>,
}

impl MldAddressRecord {
    /// Bytes before the source list: type, aux len, source count (2), address (16).
    const FIXED_LEN: usize = 4 + 16;

    /// Creates a record with no sources and no auxiliary data.
    pub fn new(record_type: u8, multicast_address: [u8; 16]) -> Self {
        Self {
            record_type,
            aux_data_len: 0,
            multicast_address,
            source_addresses: Vec::new(),
        }
    }

    /// Number of sources that will actually be serialized (clamped to the
    /// 16-bit count field instead of silently wrapping).
    fn wire_source_count(&self) -> u16 {
        // The value is <= u16::MAX after `min`, so the cast cannot truncate.
        self.source_addresses.len().min(usize::from(u16::MAX)) as u16
    }

    /// Exact number of bytes [`MldAddressRecord::serialize_into`] appends.
    pub fn wire_size(&self) -> usize {
        Self::FIXED_LEN
            + usize::from(self.wire_source_count()) * 16
            + usize::from(self.aux_data_len) * 4
    }

    /// Appends the record's wire encoding (big-endian count) to `buf`.
    pub fn serialize_into(&self, buf: &mut Vec<u8>) {
        let num_sources = self.wire_source_count();
        buf.reserve(self.wire_size());
        buf.push(self.record_type);
        buf.push(self.aux_data_len);
        buf.extend_from_slice(&num_sources.to_be_bytes());
        buf.extend_from_slice(&self.multicast_address);
        // Emit exactly as many sources as the count field claims.
        for src in self.source_addresses.iter().take(usize::from(num_sources)) {
            buf.extend_from_slice(src);
        }
        // Auxiliary data follows the sources. None is carried, so pad with
        // zero words to keep the record length consistent with `aux_data_len`
        // (otherwise a receiver would consume the next record as aux data).
        let aux_bytes = usize::from(self.aux_data_len) * 4;
        buf.resize(buf.len() + aux_bytes, 0);
    }
}
327
/// Computes the RFC 1071 Internet checksum of `data`.
///
/// Bytes are summed as big-endian 16-bit words. An odd trailing byte is
/// treated as the high byte of a word whose low byte is zero. Carries are
/// folded back into the low 16 bits and the ones' complement is returned, so
/// running this over a buffer that already contains its correct checksum
/// yields 0.
pub fn internet_checksum(data: &[u8]) -> u16 {
    let words = data.chunks_exact(2);
    let tail = words.remainder();

    // u64 accumulator: each word adds at most 0xFFFF, so this cannot overflow
    // for any buffer that fits in memory. A u32 overflows beyond ~128 KiB of
    // 0xFF bytes, which panics in debug builds and corrupts the sum in release.
    let mut sum: u64 = words
        .map(|w| u64::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();
    if let Some(&last) = tail.first() {
        sum += u64::from(last) << 8;
    }

    // End-around carry: fold until the value fits in 16 bits.
    while sum >> 16 != 0 {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }

    // `sum` is now <= 0xFFFF, so the cast is lossless.
    !(sum as u16)
}
356
/// Bookkeeping the multicast manager keeps for each joined group.
#[derive(Clone, Debug)]
pub struct GroupState {
    /// Number of outstanding joins; leaving at 1 removes the group.
    pub members: u32,
    /// Manager tick at which the group was created or last reported.
    pub last_report: u64,
    /// Ticks left before the next report; decremented once per manager tick,
    /// and a report is queued when it reaches 0.
    pub timer: u64,
    /// Interface index supplied by the join that created the entry.
    pub interface_index: u32,
}
373
374#[derive(Debug, Clone)]
376pub enum OutgoingMessage {
377 Igmp(IgmpMessage),
379 MldReport(Vec<MldAddressRecord>),
381}
382
383#[derive(Default)]
385pub struct MulticastManager {
386 groups_v4: BTreeMap<[u8; 4], GroupState>,
388 groups_v6: BTreeMap<[u8; 16], GroupState>,
390 current_tick: u64,
392 outbox: Vec<OutgoingMessage>,
394}
395
396impl MulticastManager {
397 pub fn new() -> Self {
399 Self::default()
400 }
401
402 pub fn join_group(&mut self, group: MulticastGroup) -> Result<(), MulticastError> {
404 if !is_ipv4_multicast(&group.address) {
405 return Err(MulticastError::InvalidAddress);
406 }
407
408 if let Some(state) = self.groups_v4.get_mut(&group.address) {
409 state.members += 1;
410 return Ok(());
411 }
412
413 let state = GroupState {
414 members: 1,
415 last_report: self.current_tick,
416 timer: UNSOLICITED_REPORT_INTERVAL,
417 interface_index: group.interface_index,
418 };
419 self.groups_v4.insert(group.address, state);
420
421 let report = IgmpMessage::new(IGMP_MEMBERSHIP_REPORT, 0, group.address);
423 self.outbox.push(OutgoingMessage::Igmp(report));
424
425 Ok(())
426 }
427
428 pub fn leave_group(&mut self, group: MulticastGroup) -> Result<(), MulticastError> {
431 if !is_ipv4_multicast(&group.address) {
432 return Err(MulticastError::InvalidAddress);
433 }
434
435 let state = self
436 .groups_v4
437 .get_mut(&group.address)
438 .ok_or(MulticastError::GroupNotFound)?;
439
440 state.members = state.members.saturating_sub(1);
441
442 if state.members == 0 {
443 self.groups_v4.remove(&group.address);
444 let leave = IgmpMessage::new(IGMP_LEAVE_GROUP, 0, group.address);
445 self.outbox.push(OutgoingMessage::Igmp(leave));
446 }
447
448 Ok(())
449 }
450
451 pub fn is_member(&self, address: &[u8; 4]) -> bool {
453 self.groups_v4.contains_key(address)
454 }
455
456 pub fn list_groups(&self) -> Vec<MulticastGroup> {
458 self.groups_v4
459 .iter()
460 .map(|(addr, state)| MulticastGroup {
461 address: *addr,
462 interface_index: state.interface_index,
463 })
464 .collect()
465 }
466
467 pub fn join_group_v6(&mut self, group: MulticastGroupV6) -> Result<(), MulticastError> {
469 if !is_ipv6_multicast(&group.address) {
470 return Err(MulticastError::InvalidAddress);
471 }
472
473 if let Some(state) = self.groups_v6.get_mut(&group.address) {
474 state.members += 1;
475 return Ok(());
476 }
477
478 let state = GroupState {
479 members: 1,
480 last_report: self.current_tick,
481 timer: UNSOLICITED_REPORT_INTERVAL,
482 interface_index: group.interface_index,
483 };
484 self.groups_v6.insert(group.address, state);
485
486 let record = MldAddressRecord::new(MLD_RECORD_IS_EX, group.address);
488 self.outbox
489 .push(OutgoingMessage::MldReport(alloc::vec![record]));
490
491 Ok(())
492 }
493
494 pub fn leave_group_v6(&mut self, group: MulticastGroupV6) -> Result<(), MulticastError> {
497 if !is_ipv6_multicast(&group.address) {
498 return Err(MulticastError::InvalidAddress);
499 }
500
501 let state = self
502 .groups_v6
503 .get_mut(&group.address)
504 .ok_or(MulticastError::GroupNotFound)?;
505
506 state.members = state.members.saturating_sub(1);
507
508 if state.members == 0 {
509 self.groups_v6.remove(&group.address);
510 let record = MldAddressRecord::new(MLD_RECORD_TO_IN, group.address);
511 self.outbox
512 .push(OutgoingMessage::MldReport(alloc::vec![record]));
513 }
514
515 Ok(())
516 }
517
518 pub fn is_member_v6(&self, address: &[u8; 16]) -> bool {
520 self.groups_v6.contains_key(address)
521 }
522
523 pub fn handle_query(&mut self, query: &IgmpMessage) {
525 let max_resp = query.max_resp_time as u64 * 100; if query.group_address == [0, 0, 0, 0] {
527 for state in self.groups_v4.values_mut() {
529 state.timer = max_resp.min(state.timer);
530 }
531 } else if let Some(state) = self.groups_v4.get_mut(&query.group_address) {
532 state.timer = max_resp.min(state.timer);
534 }
535 }
536
537 pub fn tick(&mut self) {
540 self.current_tick += 1;
541
542 let mut reports_v4 = Vec::new();
544 for (addr, state) in self.groups_v4.iter_mut() {
545 if state.timer > 0 {
546 state.timer -= 1;
547 }
548 if state.timer == 0 {
549 state.timer = UNSOLICITED_REPORT_INTERVAL;
550 state.last_report = self.current_tick;
551 reports_v4.push(*addr);
552 }
553 }
554 for addr in reports_v4 {
555 let report = IgmpMessage::new(IGMP_MEMBERSHIP_REPORT, 0, addr);
556 self.outbox.push(OutgoingMessage::Igmp(report));
557 }
558
559 let mut reports_v6 = Vec::new();
561 for (addr, state) in self.groups_v6.iter_mut() {
562 if state.timer > 0 {
563 state.timer -= 1;
564 }
565 if state.timer == 0 {
566 state.timer = UNSOLICITED_REPORT_INTERVAL;
567 state.last_report = self.current_tick;
568 reports_v6.push(*addr);
569 }
570 }
571 if !reports_v6.is_empty() {
572 let records: Vec<MldAddressRecord> = reports_v6
573 .into_iter()
574 .map(|addr| MldAddressRecord::new(MLD_RECORD_IS_EX, addr))
575 .collect();
576 self.outbox.push(OutgoingMessage::MldReport(records));
577 }
578 }
579
580 pub fn drain_outbox(&mut self) -> Vec<OutgoingMessage> {
582 core::mem::take(&mut self.outbox)
583 }
584
585 pub fn group_count_v4(&self) -> usize {
587 self.groups_v4.len()
588 }
589
590 pub fn group_count_v6(&self) -> usize {
592 self.groups_v6.len()
593 }
594}
595
596static MULTICAST_MANAGER: OnceLock<Mutex<MulticastManager>> = OnceLock::new();
601
602pub fn init() -> Result<(), MulticastError> {
604 MULTICAST_MANAGER
605 .set(Mutex::new(MulticastManager::new()))
606 .map_err(|_| MulticastError::NotInitialized)
607}
608
609pub fn with_manager<R, F: FnOnce(&mut MulticastManager) -> R>(f: F) -> Result<R, MulticastError> {
611 let lock = MULTICAST_MANAGER
612 .get()
613 .ok_or(MulticastError::NotInitialized)?;
614 let mut manager = lock.lock();
615 Ok(f(&mut manager))
616}
617
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    /// ff02::1, an IPv6 multicast address used by several tests.
    const FF02_1: [u8; 16] = [0xFF, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];

    #[test]
    fn test_ipv4_multicast_validation() {
        assert!(is_ipv4_multicast(&[224, 0, 0, 1]));
        assert!(is_ipv4_multicast(&[239, 255, 255, 255]));
        assert!(!is_ipv4_multicast(&[192, 168, 1, 1]));
        assert!(!is_ipv4_multicast(&[10, 0, 0, 1]));
        // Just outside 224.0.0.0/4 on either side.
        assert!(!is_ipv4_multicast(&[223, 255, 255, 255]));
        assert!(!is_ipv4_multicast(&[240, 0, 0, 0]));
    }

    #[test]
    fn test_ipv6_multicast_validation() {
        assert!(is_ipv6_multicast(&FF02_1));
        assert!(!is_ipv6_multicast(&[
            0xFE, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1
        ]));
    }

    #[test]
    fn test_multicast_group_new_valid() {
        let group = MulticastGroup::new([224, 0, 0, 1], 0);
        assert!(group.is_ok());
    }

    #[test]
    fn test_multicast_group_new_invalid() {
        let group = MulticastGroup::new([192, 168, 1, 1], 0);
        assert_eq!(group, Err(MulticastError::InvalidAddress));
    }

    #[test]
    fn test_multicast_group_v6_new_invalid() {
        let mut addr = FF02_1;
        addr[0] = 0xFE;
        assert_eq!(
            MulticastGroupV6::new(addr, 0),
            Err(MulticastError::InvalidAddress)
        );
    }

    #[test]
    fn test_igmp_message_serialize_roundtrip() {
        let msg = IgmpMessage::new(IGMP_MEMBERSHIP_REPORT, 0, [224, 0, 0, 1]);
        let bytes = msg.to_bytes();
        let parsed = IgmpMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, IGMP_MEMBERSHIP_REPORT);
        assert_eq!(parsed.group_address, [224, 0, 0, 1]);
        assert_eq!(parsed.checksum, msg.checksum);
        assert_eq!(parsed, msg);
    }

    #[test]
    fn test_igmp_checksum_verifies() {
        let msg = IgmpMessage::new(IGMP_MEMBERSHIP_QUERY, 100, [224, 0, 0, 1]);
        assert!(msg.verify_checksum());
    }

    #[test]
    fn test_igmp_bad_checksum() {
        let mut msg = IgmpMessage::new(IGMP_MEMBERSHIP_REPORT, 0, [224, 0, 0, 1]);
        msg.checksum = msg.checksum.wrapping_add(1);
        assert!(!msg.verify_checksum());
    }

    #[test]
    fn test_igmp_from_bytes_too_short() {
        let short = [0u8; 4];
        assert_eq!(
            IgmpMessage::from_bytes(&short),
            Err(MulticastError::MalformedMessage)
        );
    }

    #[test]
    fn test_internet_checksum_rfc_example() {
        // RFC 1071 §3 example: the folded sum is 0xddf2, so the checksum is 0x220d.
        let data = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
        let cksum = internet_checksum(&data);
        assert_eq!(cksum, 0x220d);

        // Appending the checksum must make the whole buffer check to zero.
        let mut verify = data.to_vec();
        verify.extend_from_slice(&cksum.to_be_bytes());
        assert_eq!(internet_checksum(&verify), 0);
    }

    #[test]
    fn test_manager_join_leave() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();

        assert!(!mgr.is_member(&group.address));

        mgr.join_group(group).unwrap();
        assert!(mgr.is_member(&group.address));
        assert_eq!(mgr.group_count_v4(), 1);

        // Joining queues exactly one membership report for the group.
        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);
        match &msgs[0] {
            OutgoingMessage::Igmp(m) => {
                assert_eq!(m.msg_type, IGMP_MEMBERSHIP_REPORT);
                assert_eq!(m.group_address, group.address);
            }
            _ => panic!("expected an IGMP membership report"),
        }

        mgr.leave_group(group).unwrap();
        assert!(!mgr.is_member(&group.address));
        assert_eq!(mgr.group_count_v4(), 0);

        // Leaving the last membership queues exactly one leave message.
        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);
        match &msgs[0] {
            OutgoingMessage::Igmp(m) => {
                assert_eq!(m.msg_type, IGMP_LEAVE_GROUP);
                assert_eq!(m.group_address, group.address);
            }
            _ => panic!("expected an IGMP leave message"),
        }
    }

    #[test]
    fn test_manager_multiple_members() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 5], 0).unwrap();

        // Two joins, but only the first one is reported.
        mgr.join_group(group).unwrap();
        mgr.join_group(group).unwrap();
        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);

        // First leave only drops the refcount: still a member, nothing sent.
        mgr.leave_group(group).unwrap();
        assert!(mgr.is_member(&group.address));
        assert!(mgr.drain_outbox().is_empty());

        // Second leave removes the group and sends the leave.
        mgr.leave_group(group).unwrap();
        assert!(!mgr.is_member(&group.address));
        assert_eq!(mgr.drain_outbox().len(), 1);
    }

    #[test]
    fn test_manager_list_groups() {
        let mut mgr = MulticastManager::new();
        let g1 = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();
        let g2 = MulticastGroup::new([239, 1, 2, 3], 1).unwrap();

        mgr.join_group(g2).unwrap();
        mgr.join_group(g1).unwrap();

        // Listed in address order, with each group's interface index preserved.
        let groups = mgr.list_groups();
        assert_eq!(groups.len(), 2);
        assert_eq!(groups.as_slice(), &[g1, g2][..]);
    }

    #[test]
    fn test_manager_leave_unknown_group() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();
        assert_eq!(mgr.leave_group(group), Err(MulticastError::GroupNotFound));
    }

    #[test]
    fn test_manager_handle_query() {
        let mut mgr = MulticastManager::new();
        let g1 = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();
        let g2 = MulticastGroup::new([239, 1, 2, 3], 0).unwrap();
        mgr.join_group(g1).unwrap();
        mgr.join_group(g2).unwrap();
        mgr.drain_outbox();

        // Group-specific query: only g1's timer is shortened (3 * 100 ticks).
        let specific = IgmpMessage::new(IGMP_MEMBERSHIP_QUERY, 3, g1.address);
        mgr.handle_query(&specific);
        assert_eq!(mgr.groups_v4[&g1.address].timer, 300);
        assert_eq!(
            mgr.groups_v4[&g2.address].timer,
            UNSOLICITED_REPORT_INTERVAL
        );

        // General query: every timer becomes min(current, 5 * 100), so g1 keeps
        // its earlier deadline and g2 is shortened.
        let general = IgmpMessage::new(IGMP_MEMBERSHIP_QUERY, 5, [0, 0, 0, 0]);
        mgr.handle_query(&general);
        assert_eq!(mgr.groups_v4[&g1.address].timer, 300);
        assert_eq!(mgr.groups_v4[&g2.address].timer, 500);
    }

    #[test]
    fn test_manager_tick_generates_report() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();
        mgr.join_group(group).unwrap();
        mgr.drain_outbox();

        // Force the report to be due on the next tick.
        mgr.groups_v4.get_mut(&group.address).unwrap().timer = 1;
        mgr.tick();

        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);
        match &msgs[0] {
            OutgoingMessage::Igmp(m) => {
                assert_eq!(m.msg_type, IGMP_MEMBERSHIP_REPORT);
                assert_eq!(m.group_address, group.address);
            }
            _ => panic!("expected an IGMP membership report"),
        }

        // The timer is re-armed and the report time recorded.
        let state = &mgr.groups_v4[&group.address];
        assert_eq!(state.timer, UNSOLICITED_REPORT_INTERVAL);
        assert_eq!(state.last_report, 1);
    }

    #[test]
    fn test_mld_message_roundtrip() {
        let msg = MldMessage::new_query(1000, FF02_1);
        let bytes = msg.to_bytes();
        let parsed = MldMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MLD_QUERY);
        assert_eq!(parsed.max_resp_delay, 1000);
        assert_eq!(parsed.multicast_address[0], 0xFF);
        assert_eq!(parsed, msg);
    }

    #[test]
    fn test_mld_address_record_serialize() {
        let record = MldAddressRecord::new(MLD_RECORD_IS_EX, FF02_1);
        let mut buf = Vec::new();
        record.serialize_into(&mut buf);
        assert_eq!(buf.len(), record.wire_size());
        assert_eq!(buf[0], MLD_RECORD_IS_EX);
        // Type, aux len, zero source count, then the address.
        assert_eq!(&buf[..4], &[MLD_RECORD_IS_EX, 0, 0, 0]);
        assert_eq!(&buf[4..20], &FF02_1);
    }

    #[test]
    fn test_manager_v6_join_leave() {
        let mut mgr = MulticastManager::new();
        let addr = FF02_1;
        let group = MulticastGroupV6::new(addr, 0).unwrap();

        mgr.join_group_v6(group).unwrap();
        assert!(mgr.is_member_v6(&addr));
        assert_eq!(mgr.group_count_v6(), 1);

        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);
        match &msgs[0] {
            OutgoingMessage::MldReport(records) => {
                assert_eq!(records.len(), 1);
                assert_eq!(records[0].multicast_address, addr);
                assert!(records[0].source_addresses.is_empty());
            }
            _ => panic!("expected an MLD report"),
        }

        mgr.leave_group_v6(group).unwrap();
        assert!(!mgr.is_member_v6(&addr));

        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);
        match &msgs[0] {
            OutgoingMessage::MldReport(records) => {
                assert_eq!(records.len(), 1);
                assert_eq!(records[0].record_type, MLD_RECORD_TO_IN);
                assert_eq!(records[0].multicast_address, addr);
            }
            _ => panic!("expected an MLD report"),
        }
    }
}
841}