⚠️ VeridianOS Kernel Documentation - This is low-level kernel code. All functions are unsafe unless explicitly marked otherwise. no_std

veridian_kernel/net/
multicast.rs

1//! Multicast group management with IGMP/MLD protocol support
2//!
3//! Provides IPv4 multicast via IGMPv2 and IPv6 multicast via MLDv2,
4//! including group join/leave, periodic report generation, and
5//! query response handling.
6
7#![allow(dead_code)] // Phase 7.5 network stack -- functions called as stack matures
8
9use alloc::{collections::BTreeMap, vec::Vec};
10
11use spin::Mutex;
12
13use crate::sync::once_lock::OnceLock;
14
15// ============================================================================
16// Constants
17// ============================================================================
18
// ---- IGMPv2 message type codes ----

/// IGMP Membership Query (`0x11`).
pub const IGMP_MEMBERSHIP_QUERY: u8 = 0x11;
/// IGMPv2 Membership Report (`0x16`).
pub const IGMP_MEMBERSHIP_REPORT: u8 = 0x16;
/// IGMPv2 Leave Group (`0x17`).
pub const IGMP_LEAVE_GROUP: u8 = 0x17;

// ---- MLDv2 message type codes (these are ICMPv6 type values) ----

/// MLD Multicast Listener Query (ICMPv6 type 130).
pub const MLD_QUERY: u8 = 130;
/// MLDv2 Multicast Listener Report (ICMPv6 type 143).
pub const MLD_REPORT_V2: u8 = 143;

// ---- MLDv2 multicast address record type codes ----

/// Current-state record, filter mode INCLUDE.
pub const MLD_RECORD_IS_IN: u8 = 1;
/// Current-state record, filter mode EXCLUDE.
pub const MLD_RECORD_IS_EX: u8 = 2;
/// Filter-mode-change record: switch to INCLUDE.
pub const MLD_RECORD_TO_IN: u8 = 3;
/// Filter-mode-change record: switch to EXCLUDE.
pub const MLD_RECORD_TO_EX: u8 = 4;
/// Source-list-change record: allow new sources.
pub const MLD_RECORD_ALLOW: u8 = 5;
/// Source-list-change record: block old sources.
pub const MLD_RECORD_BLOCK: u8 = 6;

/// Number of ticks a group's report timer is armed with by default, i.e. the
/// interval between unsolicited reports.
const UNSOLICITED_REPORT_INTERVAL: u64 = 1000;
49
50// ============================================================================
51// Error Type
52// ============================================================================
53
/// Failure modes reported by the multicast subsystem.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MulticastError {
    /// The supplied address lies outside the multicast range.
    InvalidAddress,
    /// The membership table has no entry for the requested group.
    GroupNotFound,
    /// The caller already belongs to this group.
    AlreadyMember,
    /// The maximum number of groups has been reached.
    GroupLimitReached,
    /// A message could not be serialized or deserialized.
    MalformedMessage,
    /// The manager has not been initialized.
    NotInitialized,
}
70
71// ============================================================================
72// Multicast Group Addresses
73// ============================================================================
74
75/// IPv4 multicast group identifier
76#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
77pub struct MulticastGroup {
78    /// IPv4 multicast address (must be in 224.0.0.0/4)
79    pub address: [u8; 4],
80    /// Network interface index
81    pub interface_index: u32,
82}
83
84impl MulticastGroup {
85    /// Create a new multicast group, validating the address range.
86    pub fn new(address: [u8; 4], interface_index: u32) -> Result<Self, MulticastError> {
87        if !is_ipv4_multicast(&address) {
88            return Err(MulticastError::InvalidAddress);
89        }
90        Ok(Self {
91            address,
92            interface_index,
93        })
94    }
95}
96
97/// IPv6 multicast group identifier
98#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
99pub struct MulticastGroupV6 {
100    /// IPv6 multicast address (must be in ff00::/8)
101    pub address: [u8; 16],
102    /// Network interface index
103    pub interface_index: u32,
104}
105
106impl MulticastGroupV6 {
107    /// Create a new IPv6 multicast group, validating the address range.
108    pub fn new(address: [u8; 16], interface_index: u32) -> Result<Self, MulticastError> {
109        if !is_ipv6_multicast(&address) {
110            return Err(MulticastError::InvalidAddress);
111        }
112        Ok(Self {
113            address,
114            interface_index,
115        })
116    }
117}
118
/// Reports whether `addr` belongs to 224.0.0.0/4, i.e. whether its first
/// octet is between 224 and 239 inclusive.
pub fn is_ipv4_multicast(addr: &[u8; 4]) -> bool {
    matches!(addr[0], 224..=239)
}
123
/// Reports whether `addr` belongs to ff00::/8, i.e. whether its first octet
/// is `0xFF`.
pub fn is_ipv6_multicast(addr: &[u8; 16]) -> bool {
    matches!(addr, [0xFF, ..])
}
128
129// ============================================================================
130// IGMPv2 Message
131// ============================================================================
132
133/// IGMPv2 message (8 bytes on the wire)
134#[derive(Debug, Clone, Copy, PartialEq, Eq)]
135pub struct IgmpMessage {
136    /// Message type (0x11 = Query, 0x16 = Report, 0x17 = Leave)
137    pub msg_type: u8,
138    /// Maximum response time (in 1/10 second units)
139    pub max_resp_time: u8,
140    /// Internet checksum over the entire IGMP message
141    pub checksum: u16,
142    /// Group address (0.0.0.0 for general queries)
143    pub group_address: [u8; 4],
144}
145
146impl IgmpMessage {
147    /// Size of a serialized IGMP message in bytes.
148    pub const WIRE_SIZE: usize = 8;
149
150    /// Create a new IGMP message with checksum computed automatically.
151    pub fn new(msg_type: u8, max_resp_time: u8, group_address: [u8; 4]) -> Self {
152        let mut msg = Self {
153            msg_type,
154            max_resp_time,
155            checksum: 0,
156            group_address,
157        };
158        msg.checksum = msg.compute_checksum();
159        msg
160    }
161
162    /// Serialize the message to bytes.
163    pub fn to_bytes(&self) -> [u8; Self::WIRE_SIZE] {
164        let mut buf = [0u8; Self::WIRE_SIZE];
165        buf[0] = self.msg_type;
166        buf[1] = self.max_resp_time;
167        buf[2] = (self.checksum >> 8) as u8;
168        buf[3] = self.checksum as u8;
169        buf[4..8].copy_from_slice(&self.group_address);
170        buf
171    }
172
173    /// Deserialize from bytes.
174    pub fn from_bytes(data: &[u8]) -> Result<Self, MulticastError> {
175        if data.len() < Self::WIRE_SIZE {
176            return Err(MulticastError::MalformedMessage);
177        }
178        Ok(Self {
179            msg_type: data[0],
180            max_resp_time: data[1],
181            checksum: u16::from_be_bytes([data[2], data[3]]),
182            group_address: [data[4], data[5], data[6], data[7]],
183        })
184    }
185
186    /// Compute the Internet checksum over the IGMP message.
187    pub fn compute_checksum(&self) -> u16 {
188        let mut bytes = self.to_bytes();
189        // Zero out the checksum field before computing
190        bytes[2] = 0;
191        bytes[3] = 0;
192        internet_checksum(&bytes)
193    }
194
195    /// Verify the message checksum.
196    pub fn verify_checksum(&self) -> bool {
197        let bytes = self.to_bytes();
198        internet_checksum(&bytes) == 0
199    }
200}
201
202// ============================================================================
203// MLDv2 Message
204// ============================================================================
205
206/// MLDv2 message header
207#[derive(Debug, Clone, Copy, PartialEq, Eq)]
208pub struct MldMessage {
209    /// Message type (130 = Query, 143 = Report)
210    pub msg_type: u8,
211    /// Code (subtype, typically 0)
212    pub code: u8,
213    /// Checksum (computed over pseudo-header + message)
214    pub checksum: u16,
215    /// Maximum response delay (queries) or reserved (reports)
216    pub max_resp_delay: u16,
217    /// Reserved field
218    pub reserved: u16,
219    /// Multicast address (queries) or zero (reports)
220    pub multicast_address: [u8; 16],
221}
222
223impl MldMessage {
224    /// Minimum header size in bytes.
225    pub const HEADER_SIZE: usize = 24;
226
227    /// Create a new MLD query message.
228    pub fn new_query(max_resp_delay: u16, multicast_address: [u8; 16]) -> Self {
229        Self {
230            msg_type: MLD_QUERY,
231            code: 0,
232            checksum: 0,
233            max_resp_delay,
234            reserved: 0,
235            multicast_address,
236        }
237    }
238
239    /// Create a new MLDv2 report message.
240    pub fn new_report() -> Self {
241        Self {
242            msg_type: MLD_REPORT_V2,
243            code: 0,
244            checksum: 0,
245            max_resp_delay: 0,
246            reserved: 0,
247            multicast_address: [0; 16],
248        }
249    }
250
251    /// Serialize the message header to bytes.
252    pub fn to_bytes(&self) -> [u8; Self::HEADER_SIZE] {
253        let mut buf = [0u8; Self::HEADER_SIZE];
254        buf[0] = self.msg_type;
255        buf[1] = self.code;
256        buf[2] = (self.checksum >> 8) as u8;
257        buf[3] = self.checksum as u8;
258        buf[4] = (self.max_resp_delay >> 8) as u8;
259        buf[5] = self.max_resp_delay as u8;
260        buf[6] = (self.reserved >> 8) as u8;
261        buf[7] = self.reserved as u8;
262        buf[8..24].copy_from_slice(&self.multicast_address);
263        buf
264    }
265
266    /// Deserialize from bytes.
267    pub fn from_bytes(data: &[u8]) -> Result<Self, MulticastError> {
268        if data.len() < Self::HEADER_SIZE {
269            return Err(MulticastError::MalformedMessage);
270        }
271        let mut addr = [0u8; 16];
272        addr.copy_from_slice(&data[8..24]);
273        Ok(Self {
274            msg_type: data[0],
275            code: data[1],
276            checksum: u16::from_be_bytes([data[2], data[3]]),
277            max_resp_delay: u16::from_be_bytes([data[4], data[5]]),
278            reserved: u16::from_be_bytes([data[6], data[7]]),
279            multicast_address: addr,
280        })
281    }
282}
283
/// MLDv2 multicast address record, as carried inside an MLDv2 report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MldAddressRecord {
    /// Record type (IS_IN, IS_EX, TO_IN, TO_EX, ALLOW, BLOCK)
    pub record_type: u8,
    /// Auxiliary data length (in 32-bit words).
    ///
    /// This type stores no auxiliary payload, so the value is in-memory
    /// metadata only: [`MldAddressRecord::serialize_into`] always writes zero
    /// in this position.
    pub aux_data_len: u8,
    /// Multicast address
    pub multicast_address: [u8; 16],
    /// Source addresses
    pub source_addresses: Vec<[u8; 16]>,
}

impl MldAddressRecord {
    /// Bytes before the source list: type (1), aux length (1), source count
    /// (2) and the multicast address (16).
    const FIXED_SIZE: usize = 4 + 16;

    /// Create a new address record for a group, with no sources and no
    /// auxiliary data.
    pub fn new(record_type: u8, multicast_address: [u8; 16]) -> Self {
        Self {
            record_type,
            aux_data_len: 0,
            multicast_address,
            source_addresses: Vec::new(),
        }
    }

    /// Number of sources that will actually be serialized.
    ///
    /// The on-wire count is a 16-bit field, so anything beyond `u16::MAX`
    /// sources cannot be represented and is left out.
    fn wire_source_count(&self) -> usize {
        self.source_addresses.len().min(usize::from(u16::MAX))
    }

    /// Serialized size in bytes; always equals the number of bytes
    /// [`MldAddressRecord::serialize_into`] appends.
    pub fn wire_size(&self) -> usize {
        Self::FIXED_SIZE + self.wire_source_count() * 16
    }

    /// Serialize to bytes (appended to the provided buffer).
    ///
    /// Layout: record type, aux data length, big-endian source count, the
    /// multicast address, then each source address.
    ///
    /// * The aux data length is always written as zero. No auxiliary words
    ///   are ever emitted, so advertising a non-zero length would make the
    ///   receiver misparse whatever follows; MLDv2 also requires the field to
    ///   be zero.
    /// * At most `u16::MAX` sources are written so that the count field and
    ///   the source list always agree.
    pub fn serialize_into(&self, buf: &mut Vec<u8>) {
        let count = self.wire_source_count();
        buf.reserve(self.wire_size());

        buf.push(self.record_type);
        buf.push(0);
        // Lossless: `count` was clamped to `u16::MAX` above.
        buf.extend_from_slice(&(count as u16).to_be_bytes());
        buf.extend_from_slice(&self.multicast_address);
        for src in self.source_addresses.iter().take(count) {
            buf.extend_from_slice(src);
        }
    }
}
327
328// ============================================================================
329// Checksum
330// ============================================================================
331
/// Compute the Internet checksum (RFC 1071) over a byte slice.
///
/// The data is summed as big-endian 16-bit words; an odd trailing byte is
/// treated as the high byte of a final word padded with zero. Carries are
/// folded back into the low 16 bits and the one's complement of the result is
/// returned. An empty slice yields `0xFFFF`.
///
/// The running sum is kept in a `u64`, which cannot overflow for any input
/// shorter than 2^49 bytes. A 32-bit accumulator would overflow once the
/// input exceeds roughly 128 KiB.
pub fn internet_checksum(data: &[u8]) -> u16 {
    let words = data.chunks_exact(2);
    // `remainder` is the 0- or 1-byte tail that did not fill a whole word.
    let tail = words.remainder();

    let mut sum: u64 = words
        .map(|w| u64::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();
    if let [last] = tail {
        sum += u64::from(*last) << 8;
    }

    // End-around carry: fold everything above bit 15 back into the low word.
    while sum >> 16 != 0 {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }

    // The loop above guarantees `sum` fits in 16 bits.
    !(sum as u16)
}
356
357// ============================================================================
358// Group State and Manager
359// ============================================================================
360
/// Bookkeeping kept for each joined multicast group.
#[derive(Clone, Debug)]
pub struct GroupState {
    /// How many local members (sockets) currently belong to the group.
    pub members: u32,
    /// Tick count at which the last report was sent.
    pub last_report: u64,
    /// Ticks remaining until the next unsolicited report.
    pub timer: u64,
    /// Index of the interface the group is joined on.
    pub interface_index: u32,
}
373
374/// Outgoing message produced by the manager (for the network layer to send)
375#[derive(Debug, Clone)]
376pub enum OutgoingMessage {
377    /// Send an IGMPv2 message for an IPv4 group
378    Igmp(IgmpMessage),
379    /// Send an MLDv2 report with address records for IPv6 groups
380    MldReport(Vec<MldAddressRecord>),
381}
382
383/// Manages multicast group memberships for IPv4 and IPv6.
384#[derive(Default)]
385pub struct MulticastManager {
386    /// IPv4 groups keyed by group address
387    groups_v4: BTreeMap<[u8; 4], GroupState>,
388    /// IPv6 groups keyed by group address
389    groups_v6: BTreeMap<[u8; 16], GroupState>,
390    /// Current tick counter
391    current_tick: u64,
392    /// Pending outgoing messages
393    outbox: Vec<OutgoingMessage>,
394}
395
396impl MulticastManager {
397    /// Create a new, empty multicast manager.
398    pub fn new() -> Self {
399        Self::default()
400    }
401
402    /// Join an IPv4 multicast group. Sends an immediate IGMP report.
403    pub fn join_group(&mut self, group: MulticastGroup) -> Result<(), MulticastError> {
404        if !is_ipv4_multicast(&group.address) {
405            return Err(MulticastError::InvalidAddress);
406        }
407
408        if let Some(state) = self.groups_v4.get_mut(&group.address) {
409            state.members += 1;
410            return Ok(());
411        }
412
413        let state = GroupState {
414            members: 1,
415            last_report: self.current_tick,
416            timer: UNSOLICITED_REPORT_INTERVAL,
417            interface_index: group.interface_index,
418        };
419        self.groups_v4.insert(group.address, state);
420
421        // Send immediate membership report
422        let report = IgmpMessage::new(IGMP_MEMBERSHIP_REPORT, 0, group.address);
423        self.outbox.push(OutgoingMessage::Igmp(report));
424
425        Ok(())
426    }
427
428    /// Leave an IPv4 multicast group. Sends a leave message when last member
429    /// leaves.
430    pub fn leave_group(&mut self, group: MulticastGroup) -> Result<(), MulticastError> {
431        if !is_ipv4_multicast(&group.address) {
432            return Err(MulticastError::InvalidAddress);
433        }
434
435        let state = self
436            .groups_v4
437            .get_mut(&group.address)
438            .ok_or(MulticastError::GroupNotFound)?;
439
440        state.members = state.members.saturating_sub(1);
441
442        if state.members == 0 {
443            self.groups_v4.remove(&group.address);
444            let leave = IgmpMessage::new(IGMP_LEAVE_GROUP, 0, group.address);
445            self.outbox.push(OutgoingMessage::Igmp(leave));
446        }
447
448        Ok(())
449    }
450
451    /// Check whether the given IPv4 address is a currently-joined group.
452    pub fn is_member(&self, address: &[u8; 4]) -> bool {
453        self.groups_v4.contains_key(address)
454    }
455
456    /// List all currently-joined IPv4 multicast groups.
457    pub fn list_groups(&self) -> Vec<MulticastGroup> {
458        self.groups_v4
459            .iter()
460            .map(|(addr, state)| MulticastGroup {
461                address: *addr,
462                interface_index: state.interface_index,
463            })
464            .collect()
465    }
466
467    /// Join an IPv6 multicast group. Sends an immediate MLDv2 report.
468    pub fn join_group_v6(&mut self, group: MulticastGroupV6) -> Result<(), MulticastError> {
469        if !is_ipv6_multicast(&group.address) {
470            return Err(MulticastError::InvalidAddress);
471        }
472
473        if let Some(state) = self.groups_v6.get_mut(&group.address) {
474            state.members += 1;
475            return Ok(());
476        }
477
478        let state = GroupState {
479            members: 1,
480            last_report: self.current_tick,
481            timer: UNSOLICITED_REPORT_INTERVAL,
482            interface_index: group.interface_index,
483        };
484        self.groups_v6.insert(group.address, state);
485
486        // Send immediate MLDv2 IS_EX report (no sources = join all)
487        let record = MldAddressRecord::new(MLD_RECORD_IS_EX, group.address);
488        self.outbox
489            .push(OutgoingMessage::MldReport(alloc::vec![record]));
490
491        Ok(())
492    }
493
494    /// Leave an IPv6 multicast group. Sends a TO_IN record when last member
495    /// leaves.
496    pub fn leave_group_v6(&mut self, group: MulticastGroupV6) -> Result<(), MulticastError> {
497        if !is_ipv6_multicast(&group.address) {
498            return Err(MulticastError::InvalidAddress);
499        }
500
501        let state = self
502            .groups_v6
503            .get_mut(&group.address)
504            .ok_or(MulticastError::GroupNotFound)?;
505
506        state.members = state.members.saturating_sub(1);
507
508        if state.members == 0 {
509            self.groups_v6.remove(&group.address);
510            let record = MldAddressRecord::new(MLD_RECORD_TO_IN, group.address);
511            self.outbox
512                .push(OutgoingMessage::MldReport(alloc::vec![record]));
513        }
514
515        Ok(())
516    }
517
518    /// Check whether the given IPv6 address is a currently-joined group.
519    pub fn is_member_v6(&self, address: &[u8; 16]) -> bool {
520        self.groups_v6.contains_key(address)
521    }
522
523    /// Handle an incoming IGMP query by resetting report timers.
524    pub fn handle_query(&mut self, query: &IgmpMessage) {
525        let max_resp = query.max_resp_time as u64 * 100; // Convert to ticks (~100ms units)
526        if query.group_address == [0, 0, 0, 0] {
527            // General query: reset all group timers
528            for state in self.groups_v4.values_mut() {
529                state.timer = max_resp.min(state.timer);
530            }
531        } else if let Some(state) = self.groups_v4.get_mut(&query.group_address) {
532            // Group-specific query
533            state.timer = max_resp.min(state.timer);
534        }
535    }
536
537    /// Advance the tick counter and generate unsolicited reports for groups
538    /// whose timers have expired.
539    pub fn tick(&mut self) {
540        self.current_tick += 1;
541
542        // Check IPv4 group timers
543        let mut reports_v4 = Vec::new();
544        for (addr, state) in self.groups_v4.iter_mut() {
545            if state.timer > 0 {
546                state.timer -= 1;
547            }
548            if state.timer == 0 {
549                state.timer = UNSOLICITED_REPORT_INTERVAL;
550                state.last_report = self.current_tick;
551                reports_v4.push(*addr);
552            }
553        }
554        for addr in reports_v4 {
555            let report = IgmpMessage::new(IGMP_MEMBERSHIP_REPORT, 0, addr);
556            self.outbox.push(OutgoingMessage::Igmp(report));
557        }
558
559        // Check IPv6 group timers
560        let mut reports_v6 = Vec::new();
561        for (addr, state) in self.groups_v6.iter_mut() {
562            if state.timer > 0 {
563                state.timer -= 1;
564            }
565            if state.timer == 0 {
566                state.timer = UNSOLICITED_REPORT_INTERVAL;
567                state.last_report = self.current_tick;
568                reports_v6.push(*addr);
569            }
570        }
571        if !reports_v6.is_empty() {
572            let records: Vec<MldAddressRecord> = reports_v6
573                .into_iter()
574                .map(|addr| MldAddressRecord::new(MLD_RECORD_IS_EX, addr))
575                .collect();
576            self.outbox.push(OutgoingMessage::MldReport(records));
577        }
578    }
579
580    /// Drain all pending outgoing messages.
581    pub fn drain_outbox(&mut self) -> Vec<OutgoingMessage> {
582        core::mem::take(&mut self.outbox)
583    }
584
585    /// Return count of joined IPv4 groups.
586    pub fn group_count_v4(&self) -> usize {
587        self.groups_v4.len()
588    }
589
590    /// Return count of joined IPv6 groups.
591    pub fn group_count_v6(&self) -> usize {
592        self.groups_v6.len()
593    }
594}
595
596// ============================================================================
597// Global Manager
598// ============================================================================
599
600static MULTICAST_MANAGER: OnceLock<Mutex<MulticastManager>> = OnceLock::new();
601
602/// Initialize the global multicast manager.
603pub fn init() -> Result<(), MulticastError> {
604    MULTICAST_MANAGER
605        .set(Mutex::new(MulticastManager::new()))
606        .map_err(|_| MulticastError::NotInitialized)
607}
608
609/// Access the global multicast manager.
610pub fn with_manager<R, F: FnOnce(&mut MulticastManager) -> R>(f: F) -> Result<R, MulticastError> {
611    let lock = MULTICAST_MANAGER
612        .get()
613        .ok_or(MulticastError::NotInitialized)?;
614    let mut manager = lock.lock();
615    Ok(f(&mut manager))
616}
617
618// ============================================================================
619// Tests
620// ============================================================================
621
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    /// Link-local all-nodes address, used as the IPv6 group in these tests.
    const ALL_NODES_V6: [u8; 16] = [0xFF, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];

    #[test]
    fn test_ipv4_multicast_validation() {
        assert!(is_ipv4_multicast(&[224, 0, 0, 1]));
        assert!(is_ipv4_multicast(&[239, 255, 255, 255]));
        assert!(!is_ipv4_multicast(&[192, 168, 1, 1]));
        assert!(!is_ipv4_multicast(&[10, 0, 0, 1]));
    }

    #[test]
    fn test_ipv6_multicast_validation() {
        assert!(is_ipv6_multicast(&ALL_NODES_V6));
        assert!(!is_ipv6_multicast(&[
            0xFE, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1
        ]));
    }

    #[test]
    fn test_multicast_group_new_valid() {
        let group = MulticastGroup::new([224, 0, 0, 1], 0);
        assert!(group.is_ok());
    }

    #[test]
    fn test_multicast_group_new_invalid() {
        let group = MulticastGroup::new([192, 168, 1, 1], 0);
        assert_eq!(group, Err(MulticastError::InvalidAddress));
    }

    #[test]
    fn test_igmp_message_serialize_roundtrip() {
        let msg = IgmpMessage::new(IGMP_MEMBERSHIP_REPORT, 0, [224, 0, 0, 1]);
        let bytes = msg.to_bytes();
        let parsed = IgmpMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, IGMP_MEMBERSHIP_REPORT);
        assert_eq!(parsed.group_address, [224, 0, 0, 1]);
        assert_eq!(parsed.checksum, msg.checksum);
    }

    #[test]
    fn test_igmp_checksum_verifies() {
        let msg = IgmpMessage::new(IGMP_MEMBERSHIP_QUERY, 100, [224, 0, 0, 1]);
        assert!(msg.verify_checksum());
    }

    #[test]
    fn test_igmp_bad_checksum() {
        let mut msg = IgmpMessage::new(IGMP_MEMBERSHIP_REPORT, 0, [224, 0, 0, 1]);
        msg.checksum = msg.checksum.wrapping_add(1); // Corrupt it
        assert!(!msg.verify_checksum());
    }

    #[test]
    fn test_igmp_from_bytes_too_short() {
        let short = [0u8; 4];
        assert_eq!(
            IgmpMessage::from_bytes(&short),
            Err(MulticastError::MalformedMessage)
        );
    }

    #[test]
    fn test_internet_checksum_rfc_example() {
        // RFC 1071 example words: 0x0001 + 0xf203 + 0xf4f5 + 0xf6f7 folds to
        // 0xddf2, whose complement is 0x220d.
        let data = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
        let cksum = internet_checksum(&data);
        assert_eq!(cksum, 0x220d);

        // Verify: applying checksum to data+checksum should yield 0
        let mut verify = data.to_vec();
        verify.push((cksum >> 8) as u8);
        verify.push(cksum as u8);
        assert_eq!(internet_checksum(&verify), 0);
    }

    #[test]
    fn test_manager_join_leave() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();

        assert!(!mgr.is_member(&group.address));

        mgr.join_group(group).unwrap();
        assert!(mgr.is_member(&group.address));
        assert_eq!(mgr.group_count_v4(), 1);

        // Should generate an IGMP report
        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);

        mgr.leave_group(group).unwrap();
        assert!(!mgr.is_member(&group.address));
        assert_eq!(mgr.group_count_v4(), 0);

        // Should generate an IGMP leave
        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);
    }

    #[test]
    fn test_manager_multiple_members() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 5], 0).unwrap();

        mgr.join_group(group).unwrap();
        mgr.join_group(group).unwrap(); // Second join increments member count

        // One report for the initial join only
        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);

        // First leave decrements count but group remains
        mgr.leave_group(group).unwrap();
        assert!(mgr.is_member(&group.address));
        assert!(mgr.drain_outbox().is_empty()); // No leave message yet

        // Second leave removes the group
        mgr.leave_group(group).unwrap();
        assert!(!mgr.is_member(&group.address));
        assert_eq!(mgr.drain_outbox().len(), 1); // Leave message now
    }

    #[test]
    fn test_manager_list_groups() {
        let mut mgr = MulticastManager::new();
        let g1 = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();
        let g2 = MulticastGroup::new([239, 1, 2, 3], 1).unwrap();

        mgr.join_group(g1).unwrap();
        mgr.join_group(g2).unwrap();

        let groups = mgr.list_groups();
        assert_eq!(groups.len(), 2);
    }

    #[test]
    fn test_manager_leave_unknown_group() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();
        assert_eq!(mgr.leave_group(group), Err(MulticastError::GroupNotFound));
    }

    #[test]
    fn test_manager_handle_query() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();
        mgr.join_group(group).unwrap();
        mgr.drain_outbox(); // Clear join report

        // A general query with a 10 second max response time (100 * 0.1 s ->
        // 10_000 ticks) is longer than the pending timer, so nothing changes.
        let query = IgmpMessage::new(IGMP_MEMBERSHIP_QUERY, 100, [0, 0, 0, 0]);
        mgr.handle_query(&query);
        assert_eq!(
            mgr.groups_v4.get(&group.address).unwrap().timer,
            UNSOLICITED_REPORT_INTERVAL
        );

        // A shorter general query (0.5 s -> 500 ticks) pulls the timer in.
        let query = IgmpMessage::new(IGMP_MEMBERSHIP_QUERY, 5, [0, 0, 0, 0]);
        mgr.handle_query(&query);
        assert_eq!(mgr.groups_v4.get(&group.address).unwrap().timer, 500);

        // A group-specific query for the joined group shortens it further.
        let query = IgmpMessage::new(IGMP_MEMBERSHIP_QUERY, 2, group.address);
        mgr.handle_query(&query);
        assert_eq!(mgr.groups_v4.get(&group.address).unwrap().timer, 200);

        // A group-specific query for some other group leaves it untouched.
        let query = IgmpMessage::new(IGMP_MEMBERSHIP_QUERY, 1, [224, 0, 0, 9]);
        mgr.handle_query(&query);
        assert_eq!(mgr.groups_v4.get(&group.address).unwrap().timer, 200);

        // Handling a query never queues a message by itself.
        assert!(mgr.drain_outbox().is_empty());
    }

    #[test]
    fn test_manager_tick_generates_report() {
        let mut mgr = MulticastManager::new();
        let group = MulticastGroup::new([224, 0, 0, 1], 0).unwrap();
        mgr.join_group(group).unwrap();
        mgr.drain_outbox(); // Clear join report

        // Set timer to 1 so next tick triggers a report
        mgr.groups_v4.get_mut(&group.address).unwrap().timer = 1;
        mgr.tick();

        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);
    }

    #[test]
    fn test_mld_message_roundtrip() {
        let msg = MldMessage::new_query(1000, ALL_NODES_V6);
        let bytes = msg.to_bytes();
        let parsed = MldMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MLD_QUERY);
        assert_eq!(parsed.max_resp_delay, 1000);
        assert_eq!(parsed.multicast_address[0], 0xFF);
    }

    #[test]
    fn test_mld_address_record_serialize() {
        let record = MldAddressRecord::new(MLD_RECORD_IS_EX, ALL_NODES_V6);
        let mut buf = Vec::new();
        record.serialize_into(&mut buf);
        assert_eq!(buf.len(), record.wire_size());
        assert_eq!(buf[0], MLD_RECORD_IS_EX);
    }

    #[test]
    fn test_manager_v6_join_leave() {
        let mut mgr = MulticastManager::new();
        let addr = ALL_NODES_V6;
        let group = MulticastGroupV6::new(addr, 0).unwrap();

        mgr.join_group_v6(group).unwrap();
        assert!(mgr.is_member_v6(&addr));
        assert_eq!(mgr.group_count_v6(), 1);

        let msgs = mgr.drain_outbox();
        assert_eq!(msgs.len(), 1);

        mgr.leave_group_v6(group).unwrap();
        assert!(!mgr.is_member_v6(&addr));
        assert_eq!(mgr.drain_outbox().len(), 1);
    }
}