⚠️ VeridianOS Kernel Documentation - This is low-level kernel code. All functions are unsafe unless explicitly marked otherwise. no_std

veridian_kernel/net/
mdns.rs

1//! mDNS/DNS-SD: Multicast DNS and Service Discovery for VeridianOS
2//!
3//! Implements RFC 6762 (Multicast DNS) and RFC 6763 (DNS-Based Service
4//! Discovery) for zero-configuration networking on the `.local` domain.
5//!
6//! Features:
7//! - mDNS query/response on 224.0.0.251:5353 (IPv4) / [ff02::fb]:5353 (IPv6)
8//! - One-shot and continuous query modes
9//! - Known-answer suppression
10//! - Conflict resolution: probe (3x, 250ms) then announce (2x, 1s)
11//! - DNS-SD service registration, browsing, and deregistration
12//! - PTR/SRV/TXT record handling for service instances
13//! - TTL-based cache with expiry
14//! - Goodbye packets (TTL=0) on shutdown
15
16#![allow(dead_code)]
17
18#[cfg(feature = "alloc")]
19use alloc::{collections::BTreeMap, string::String, vec::Vec};
20
21// ============================================================================
22// Constants
23// ============================================================================
24
/// mDNS multicast IPv4 address: 224.0.0.251
pub const MDNS_IPV4_ADDR: [u8; 4] = [224, 0, 0, 251];

/// mDNS multicast IPv6 address: ff02::fb (16 bytes, network order)
pub const MDNS_IPV6_ADDR: [u8; 16] = [0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfb];

/// mDNS port (used for both the IPv4 and IPv6 multicast groups)
pub const MDNS_PORT: u16 = 5353;

/// Default TTL, in seconds, for unique records (host addresses, SRV, TXT)
pub const TTL_UNIQUE: u32 = 120;

/// Default TTL, in seconds, for shared records (PTR, service types);
/// 4500 s = 75 minutes
pub const TTL_SHARED: u32 = 4500;

/// Probe interval in milliseconds; converted to ticks with the responder's
/// ticks-per-second value
pub const PROBE_INTERVAL_MS: u64 = 250;

/// Number of probe attempts before claiming a name
pub const PROBE_COUNT: u8 = 3;

/// Announce interval in milliseconds; converted to ticks with the responder's
/// ticks-per-second value
pub const ANNOUNCE_INTERVAL_MS: u64 = 1000;

/// Number of announcements after claiming a name
pub const ANNOUNCE_COUNT: u8 = 2;

/// Maximum mDNS message size (same as DNS over UDP)
const MAX_MDNS_MSG_SIZE: usize = 512;

/// Maximum label length (per DNS spec)
const MAX_LABEL_LEN: usize = 63;

/// Maximum domain name length
const MAX_NAME_LEN: usize = 255;

/// Maximum cached entries, counted across all names in the responder cache
const MAX_CACHE_ENTRIES: usize = 512;

/// Maximum registered services per responder
const MAX_SERVICES: usize = 64;

/// Maximum TXT record pairs accepted by `TxtRecord::add`
const MAX_TXT_PAIRS: usize = 16;

/// QU (unicast response) bit in class field (top bit of the 16-bit class)
const QU_BIT: u16 = 0x8000;

/// Cache-flush bit in class field (for responses); same bit position as
/// `QU_BIT`, which applies to questions
const CACHE_FLUSH_BIT: u16 = 0x8000;

/// `.local` suffix, including the leading dot, appended to the bare hostname
pub const LOCAL_SUFFIX: &str = ".local";
78
79// ============================================================================
80// DNS Record Types (local definitions to avoid coupling)
81// ============================================================================
82
/// Resource record types understood by the mDNS/DNS-SD code.
///
/// Each discriminant is the numeric DNS type code, so a variant can be cast
/// straight to its wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u16)]
pub enum MdnsRecordType {
    /// IPv4 host address
    A = 1,
    /// Pointer to another domain name (service enumeration / reverse lookup)
    PTR = 12,
    /// Text record carrying key=value metadata
    TXT = 16,
    /// IPv6 host address
    AAAA = 28,
    /// Service locator: priority, weight, port and target host
    SRV = 33,
    /// Wildcard type, only meaningful in questions
    ANY = 255,
    /// Any type code this module does not handle
    Unknown = 0,
}

impl MdnsRecordType {
    /// Map a numeric DNS type code to its variant.
    ///
    /// Codes that are not listed in the enum (including `0`) yield
    /// [`MdnsRecordType::Unknown`].
    pub fn from_u16(val: u16) -> Self {
        let known = [
            Self::A,
            Self::PTR,
            Self::TXT,
            Self::AAAA,
            Self::SRV,
            Self::ANY,
        ];
        known
            .iter()
            .copied()
            .find(|candidate| *candidate as u16 == val)
            .unwrap_or(Self::Unknown)
    }

    /// Numeric DNS type code of this variant (`Unknown` maps to `0`).
    pub fn to_u16(self) -> u16 {
        let code = self as u16;
        code
    }
}
120
121/// DNS class (IN = Internet)
122#[derive(Debug, Clone, Copy, PartialEq, Eq)]
123#[repr(u16)]
124pub enum MdnsClass {
125    /// Internet class
126    IN = 1,
127    /// Any class (query wildcard)
128    ANY = 255,
129}
130
131impl MdnsClass {
132    pub fn from_u16(val: u16) -> Self {
133        match val & !QU_BIT {
134            1 => Self::IN,
135            255 => Self::ANY,
136            _ => Self::IN,
137        }
138    }
139}
140
141// ============================================================================
142// Error Type
143// ============================================================================
144
/// Errors from mDNS operations.
///
/// The type is `Copy` and carries no payload, so it is cheap to return and
/// compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MdnsError {
    /// Message too short to parse
    MessageTooShort,
    /// Invalid label in domain name
    InvalidLabel,
    /// Name exceeds maximum length
    NameTooLong,
    /// Buffer too small for serialization
    BufferTooSmall,
    /// Service limit reached (returned by `MdnsResponder::register_service`
    /// once `MAX_SERVICES` instances are registered)
    TooManyServices,
    /// Name conflict detected during probing
    NameConflict,
    /// Cache is full (returned by `MdnsResponder::cache_insert` when
    /// `MAX_CACHE_ENTRIES` is reached)
    CacheFull,
    /// Invalid service type format (returned by `ServiceType::parse`)
    InvalidServiceType,
    /// TXT record too large (returned by `TxtRecord::add` when the pair count
    /// or the 255-byte per-entry limit would be exceeded)
    TxtTooLarge,
    /// Record not found
    NotFound,
    /// Invalid message format (e.g. returned by `TxtRecord::decode` for a
    /// truncated string or invalid UTF-8)
    InvalidFormat,
}
171
172// ============================================================================
173// Record data types
174// ============================================================================
175
/// SRV record data (RFC 2782).
///
/// Built from a `ServiceInstance` by `ServiceInstance::to_srv`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct SrvRecord {
    /// Priority (lower = preferred)
    pub priority: u16,
    /// Weight for load balancing among same-priority targets
    pub weight: u16,
    /// Port number the service listens on
    pub port: u16,
    /// Target hostname providing the service
    pub target: String,
}
189
/// TXT record: an ordered collection of `key=value` attributes
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct TxtRecord {
    /// Key-value pairs, in the order they were added or decoded
    pub entries: Vec<TxtEntry>,
}

/// Single TXT record entry
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct TxtEntry {
    /// Key (before '=')
    pub key: String,
    /// Value (after '='), empty if boolean key
    pub value: String,
}

#[cfg(feature = "alloc")]
impl Default for TxtRecord {
    /// Same as [`TxtRecord::new`]: a record with no entries.
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "alloc")]
impl TxtRecord {
    /// Create an empty TXT record
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Add a key=value pair.
    ///
    /// # Errors
    ///
    /// * `MdnsError::TxtTooLarge` if the record already holds
    ///   `MAX_TXT_PAIRS` entries, or if `key` + `=` + `value` would not fit in
    ///   a single 255-byte length-prefixed string.
    /// * `MdnsError::InvalidFormat` if `key` is empty or contains `=`. Such a
    ///   key cannot survive an encode/decode round trip, because the decoder
    ///   splits at the first `=` (RFC 6763 §6.4 forbids both).
    pub fn add(&mut self, key: &str, value: &str) -> Result<(), MdnsError> {
        if self.entries.len() >= MAX_TXT_PAIRS {
            return Err(MdnsError::TxtTooLarge);
        }
        if key.is_empty() || key.contains('=') {
            return Err(MdnsError::InvalidFormat);
        }
        // Each entry is length-prefixed; key=value must fit in 255 bytes
        if key.len() + 1 + value.len() > 255 {
            return Err(MdnsError::TxtTooLarge);
        }
        self.entries.push(TxtEntry {
            key: String::from(key),
            value: String::from(value),
        });
        Ok(())
    }

    /// Encode TXT record to wire format (length-prefixed strings).
    ///
    /// An entry with an empty value is written as the bare key (no `=`).
    /// Entries whose encoded form exceeds 255 bytes are skipped; `add` never
    /// produces such entries, but `entries` is public and may be filled
    /// directly. A record that encodes to nothing is emitted as a single zero
    /// byte, as RFC 6763 requires.
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        for entry in &self.entries {
            let has_value = !entry.value.is_empty();
            let len = if has_value {
                entry.key.len() + 1 + entry.value.len()
            } else {
                entry.key.len()
            };
            if len > 255 {
                continue;
            }
            // Write straight into the output instead of building a temporary
            // `String` per entry.
            buf.push(len as u8);
            buf.extend_from_slice(entry.key.as_bytes());
            if has_value {
                buf.push(b'=');
                buf.extend_from_slice(entry.value.as_bytes());
            }
        }
        // RFC 6763: empty TXT record must contain single zero byte
        if buf.is_empty() {
            buf.push(0);
        }
        buf
    }

    /// Decode TXT record from wire format.
    ///
    /// Zero-length strings are skipped. Strings that begin with `=` (missing
    /// key) are silently ignored, as RFC 6763 §6.4 requires. A string without
    /// `=` becomes an entry with an empty value.
    ///
    /// # Errors
    ///
    /// `MdnsError::InvalidFormat` if a length prefix runs past the end of
    /// `data` or a string is not valid UTF-8.
    pub fn decode(data: &[u8]) -> Result<Self, MdnsError> {
        let mut entries = Vec::new();
        let mut pos = 0;
        while pos < data.len() {
            let len = data[pos] as usize;
            pos += 1;
            if len == 0 {
                continue;
            }
            let raw = data
                .get(pos..pos + len)
                .ok_or(MdnsError::InvalidFormat)?;
            pos += len;
            let s = core::str::from_utf8(raw).map_err(|_| MdnsError::InvalidFormat)?;
            let (key, value) = match s.find('=') {
                // "=value": no key, ignore the string
                Some(0) => continue,
                Some(eq_pos) => (&s[..eq_pos], &s[eq_pos + 1..]),
                None => (s, ""),
            };
            entries.push(TxtEntry {
                key: String::from(key),
                value: String::from(value),
            });
        }
        Ok(Self { entries })
    }

    /// Look up a value by key.
    ///
    /// Returns the value of the first entry whose key matches exactly
    /// (case-sensitive), or `None` if there is no such entry. A boolean key
    /// yields `Some("")`.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.entries
            .iter()
            .find(|entry| entry.key == key)
            .map(|entry| entry.value.as_str())
    }
}
306
307// ============================================================================
308// mDNS Resource Record
309// ============================================================================
310
/// An mDNS resource record.
///
/// Used for generated answers, goodbye records (`ttl == 0`) and cached
/// records alike.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct MdnsRecord {
    /// Fully-qualified domain name
    pub name: String,
    /// Record type
    pub rtype: MdnsRecordType,
    /// Cache-flush (unique record); when set, `MdnsResponder::cache_insert`
    /// evicts cached records of the same name and type first
    pub cache_flush: bool,
    /// Time-to-live in seconds; 0 marks a goodbye record
    pub ttl: u32,
    /// Record data (wire format)
    pub rdata: Vec<u8>,
}
326
/// An mDNS question.
///
/// `MdnsResponder::has_answer` and `MdnsResponder::answer` compare `name`
/// by exact string equality against the responder's own names.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct MdnsQuestion {
    /// Queried name
    pub name: String,
    /// Query type (`ANY` acts as a wildcard)
    pub qtype: MdnsRecordType,
    /// Unicast response requested (QU bit)
    pub unicast: bool,
}
338
339// ============================================================================
340// Service Type Parsing
341// ============================================================================
342
/// A DNS-SD service type split into its components.
///
/// The leading underscores of the service and protocol labels are not stored;
/// `to_service_string` adds them back.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct ServiceType {
    /// Service name without the underscore (e.g., "http")
    pub service: String,
    /// Protocol without the underscore ("tcp" or "udp")
    pub protocol: String,
    /// Domain (e.g., "local"); may contain several dot-separated labels
    pub domain: String,
}

#[cfg(feature = "alloc")]
impl ServiceType {
    /// Parse a service type string such as "_http._tcp.local".
    ///
    /// Trailing dots are ignored. The first label must be `_` followed by at
    /// least one character, the second must be `_tcp` or `_udp`, and
    /// everything after the second dot is kept verbatim as the domain.
    ///
    /// # Errors
    ///
    /// `MdnsError::InvalidServiceType` if fewer than three labels are present
    /// or either of the first two labels is malformed.
    pub fn parse(s: &str) -> Result<Self, MdnsError> {
        let mut labels = s.trim_end_matches('.').splitn(3, '.');
        let (service_label, proto_label, domain) =
            match (labels.next(), labels.next(), labels.next()) {
                (Some(service), Some(proto), Some(rest)) => (service, proto, rest),
                _ => return Err(MdnsError::InvalidServiceType),
            };

        // "_name": underscore plus a non-empty service name
        let service = match service_label.strip_prefix('_') {
            Some(name) if !name.is_empty() => name,
            _ => return Err(MdnsError::InvalidServiceType),
        };

        // Only the two transport labels are accepted
        let protocol = match proto_label {
            "_tcp" => "tcp",
            "_udp" => "udp",
            _ => return Err(MdnsError::InvalidServiceType),
        };

        Ok(Self {
            service: String::from(service),
            protocol: String::from(protocol),
            domain: String::from(domain),
        })
    }

    /// Render as `_<service>._<protocol>.<domain>`
    pub fn to_service_string(&self) -> String {
        let pieces = [
            "_",
            self.service.as_str(),
            "._",
            self.protocol.as_str(),
            ".",
            self.domain.as_str(),
        ];
        let mut out = String::new();
        for piece in pieces.iter() {
            out.push_str(piece);
        }
        out
    }
}
401
402// ============================================================================
403// DNS-SD Service Instance
404// ============================================================================
405
/// One DNS-SD service instance together with the data needed to build its
/// PTR, SRV and TXT records
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct ServiceInstance {
    /// Human-readable instance name (e.g., "My Web Server")
    pub instance_name: String,
    /// Parsed service type (e.g., "_http._tcp.local")
    pub service_type: ServiceType,
    /// Port the service listens on
    pub port: u16,
    /// Hostname placed in the SRV target field
    pub target: String,
    /// SRV priority
    pub priority: u16,
    /// SRV weight
    pub weight: u16,
    /// Metadata published in the TXT record
    pub txt: TxtRecord,
}

#[cfg(feature = "alloc")]
impl ServiceInstance {
    /// Full instance name: `<instance_name>.<service type string>`,
    /// e.g., "My Web Server._http._tcp.local"
    pub fn full_name(&self) -> String {
        let service = self.service_type.to_service_string();
        let mut full = String::new();
        full.push_str(&self.instance_name);
        full.push('.');
        full.push_str(&service);
        full
    }

    /// Owner name of the PTR record for this instance, i.e. the service type
    /// string
    pub fn ptr_name(&self) -> String {
        self.service_type.to_service_string()
    }

    /// Copy the SRV-related fields into an `SrvRecord`
    pub fn to_srv(&self) -> SrvRecord {
        let target = self.target.clone();
        SrvRecord {
            priority: self.priority,
            weight: self.weight,
            port: self.port,
            target,
        }
    }

    /// Encode SRV rdata to wire format: priority, weight and port as
    /// big-endian `u16`s, followed by the target encoded by
    /// `encode_dns_name`
    pub fn encode_srv_rdata(&self) -> Vec<u8> {
        let mut rdata = Vec::new();
        for field in [self.priority, self.weight, self.port].iter() {
            rdata.extend_from_slice(&field.to_be_bytes());
        }
        rdata.extend_from_slice(&encode_dns_name(&self.target));
        rdata
    }
}
464
465// ============================================================================
466// Cache Entry
467// ============================================================================
468
/// A cached mDNS record with expiry tracking
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct CacheEntry {
    /// The resource record
    pub record: MdnsRecord,
    /// Tick count when this entry was inserted
    pub inserted_tick: u64,
    /// TTL in seconds at insertion time
    pub original_ttl: u32,
}

#[cfg(feature = "alloc")]
impl CacheEntry {
    /// Whole seconds elapsed since insertion, or `None` when `ticks_per_sec`
    /// is zero (time cannot be measured). A `current_tick` earlier than the
    /// insertion tick counts as zero elapsed time.
    fn elapsed_secs(&self, current_tick: u64, ticks_per_sec: u64) -> Option<u64> {
        current_tick
            .saturating_sub(self.inserted_tick)
            .checked_div(ticks_per_sec)
    }

    /// Check if this entry has expired given the current tick and
    /// ticks-per-second.
    ///
    /// An entry expires once at least `original_ttl` whole seconds have
    /// elapsed. With `ticks_per_sec == 0` an entry never expires.
    pub fn is_expired(&self, current_tick: u64, ticks_per_sec: u64) -> bool {
        match self.elapsed_secs(current_tick, ticks_per_sec) {
            Some(elapsed) => elapsed >= u64::from(self.original_ttl),
            None => false,
        }
    }

    /// Remaining TTL in seconds, 0 once the entry has expired.
    ///
    /// With `ticks_per_sec == 0` the original TTL is returned unchanged.
    pub fn remaining_ttl(&self, current_tick: u64, ticks_per_sec: u64) -> u32 {
        match self.elapsed_secs(current_tick, ticks_per_sec) {
            // Subtract in u64: narrowing the elapsed time to u32 first would
            // wrap for very large values and report a non-zero TTL for an
            // entry that expired long ago. The difference never exceeds
            // `original_ttl`, so the final cast is lossless.
            Some(elapsed) => u64::from(self.original_ttl).saturating_sub(elapsed) as u32,
            None => self.original_ttl,
        }
    }
}
504
505// ============================================================================
506// Probe State Machine
507// ============================================================================
508
/// Progress of claiming a name through probing and announcing
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbeState {
    /// Registered, probing has not begun
    Idle,
    /// Probing: `sent` probes counted so far, next step due at `next_tick`
    Probing { sent: u8, next_tick: u64 },
    /// Announcing: `sent` announcements counted so far, next step due at
    /// `next_tick`
    Announcing { sent: u8, next_tick: u64 },
    /// Name successfully claimed
    Claimed,
    /// Conflict detected, must choose new name
    Conflict,
}

impl ProbeState {
    /// `true` only in the `Claimed` state
    pub fn is_claimed(&self) -> bool {
        *self == ProbeState::Claimed
    }

    /// `true` only in the `Conflict` state
    pub fn is_conflict(&self) -> bool {
        *self == ProbeState::Conflict
    }
}
535
536// ============================================================================
537// mDNS Responder
538// ============================================================================
539
/// Core mDNS responder managing registered services and cached records.
///
/// The responder holds state only; the methods visible in this file build
/// and inspect records but do not perform network I/O themselves.
#[cfg(feature = "alloc")]
pub struct MdnsResponder {
    /// Our hostname (without .local suffix)
    hostname: String,
    /// Registered service instances (at most `MAX_SERVICES`)
    services: Vec<ServiceInstance>,
    /// Record cache (name -> list of records)
    cache: BTreeMap<String, Vec<CacheEntry>>,
    /// Probe state per name being claimed, keyed by full name
    probe_states: BTreeMap<String, ProbeState>,
    /// Monotonic tick counter reference (ticks per second); used to convert
    /// millisecond intervals and TTLs to ticks
    ticks_per_sec: u64,
    /// Host IPv4 address (4 bytes); all zeros until `set_ipv4` is called
    host_ipv4: [u8; 4],
    /// Host IPv6 address (16 bytes); all zeros until `set_ipv6` is called
    host_ipv6: [u8; 16],
}
558
559#[cfg(feature = "alloc")]
560impl MdnsResponder {
561    /// Create a new mDNS responder
562    pub fn new(hostname: &str, ticks_per_sec: u64) -> Self {
563        Self {
564            hostname: String::from(hostname),
565            services: Vec::new(),
566            cache: BTreeMap::new(),
567            probe_states: BTreeMap::new(),
568            ticks_per_sec,
569            host_ipv4: [0; 4],
570            host_ipv6: [0; 16],
571        }
572    }
573
    /// Set the host IPv4 address.
    ///
    /// The stored bytes become the rdata of A answers built by `answer` and
    /// the reference value compared in `detect_conflict`.
    pub fn set_ipv4(&mut self, addr: [u8; 4]) {
        self.host_ipv4 = addr;
    }
578
    /// Set the host IPv6 address.
    ///
    /// The stored bytes become the rdata of AAAA answers built by `answer`
    /// and the reference value compared in `detect_conflict`.
    pub fn set_ipv6(&mut self, addr: [u8; 16]) {
        self.host_ipv6 = addr;
    }
583
584    /// Get our fully-qualified hostname (hostname.local)
585    pub fn fqdn(&self) -> String {
586        let mut name = self.hostname.clone();
587        name.push_str(LOCAL_SUFFIX);
588        name
589    }
590
591    // ---- Service Registration ----
592
593    /// Register a new service instance
594    pub fn register_service(&mut self, svc: ServiceInstance) -> Result<(), MdnsError> {
595        if self.services.len() >= MAX_SERVICES {
596            return Err(MdnsError::TooManyServices);
597        }
598        // Start probing for the service name
599        let full_name = svc.full_name();
600        self.probe_states.insert(full_name, ProbeState::Idle);
601        self.services.push(svc);
602        Ok(())
603    }
604
605    /// Deregister a service by instance name, returning goodbye records
606    pub fn deregister_service(&mut self, instance_name: &str) -> Vec<MdnsRecord> {
607        let mut goodbyes = Vec::new();
608        let mut names_to_remove = Vec::new();
609
610        self.services.retain(|svc| {
611            if svc.instance_name == instance_name {
612                // Generate goodbye packets (TTL=0) for PTR, SRV, TXT
613                goodbyes.push(MdnsRecord {
614                    name: svc.ptr_name(),
615                    rtype: MdnsRecordType::PTR,
616                    cache_flush: false,
617                    ttl: 0,
618                    rdata: encode_dns_name(&svc.full_name()),
619                });
620                goodbyes.push(MdnsRecord {
621                    name: svc.full_name(),
622                    rtype: MdnsRecordType::SRV,
623                    cache_flush: true,
624                    ttl: 0,
625                    rdata: svc.encode_srv_rdata(),
626                });
627                goodbyes.push(MdnsRecord {
628                    name: svc.full_name(),
629                    rtype: MdnsRecordType::TXT,
630                    cache_flush: true,
631                    ttl: 0,
632                    rdata: svc.txt.encode(),
633                });
634                names_to_remove.push(svc.full_name());
635                false // remove from list
636            } else {
637                true // keep
638            }
639        });
640
641        for name in &names_to_remove {
642            self.probe_states.remove(name);
643        }
644
645        goodbyes
646    }
647
648    /// Browse for services of a given type in our cache
649    pub fn browse_services(&self, service_type: &str) -> Vec<&ServiceInstance> {
650        // Check our own registered services
651        let mut results = Vec::new();
652        for svc in &self.services {
653            if svc.service_type.to_service_string() == service_type {
654                results.push(svc);
655            }
656        }
657        results
658    }
659
660    /// Look up a service instance by full name
661    pub fn lookup_service(&self, full_name: &str) -> Option<&ServiceInstance> {
662        self.services
663            .iter()
664            .find(|svc| svc.full_name() == full_name)
665    }
666
667    // ---- Probing & Announcing ----
668
669    /// Start probing for a name at the given tick
670    pub fn start_probe(&mut self, name: &str, current_tick: u64) {
671        let interval_ticks = (PROBE_INTERVAL_MS * self.ticks_per_sec) / 1000;
672        self.probe_states.insert(
673            String::from(name),
674            ProbeState::Probing {
675                sent: 1, // First probe is sent immediately on start
676                next_tick: current_tick + interval_ticks,
677            },
678        );
679    }
680
681    /// Advance probe/announce state machine; returns names that need packets
682    /// sent
683    pub fn tick_probes(&mut self, current_tick: u64) -> Vec<(String, ProbeState)> {
684        let mut actions = Vec::new();
685        let probe_interval = (PROBE_INTERVAL_MS * self.ticks_per_sec) / 1000;
686        let announce_interval = (ANNOUNCE_INTERVAL_MS * self.ticks_per_sec) / 1000;
687
688        for (name, state) in self.probe_states.iter_mut() {
689            match *state {
690                ProbeState::Probing { sent, next_tick } if current_tick >= next_tick => {
691                    if sent + 1 >= PROBE_COUNT {
692                        // Probing complete, start announcing
693                        *state = ProbeState::Announcing {
694                            sent: 0,
695                            next_tick: current_tick + announce_interval,
696                        };
697                    } else {
698                        *state = ProbeState::Probing {
699                            sent: sent + 1,
700                            next_tick: current_tick + probe_interval,
701                        };
702                    }
703                    actions.push((name.clone(), *state));
704                }
705                ProbeState::Announcing { sent, next_tick } if current_tick >= next_tick => {
706                    if sent + 1 >= ANNOUNCE_COUNT {
707                        *state = ProbeState::Claimed;
708                    } else {
709                        *state = ProbeState::Announcing {
710                            sent: sent + 1,
711                            next_tick: current_tick + announce_interval,
712                        };
713                    }
714                    actions.push((name.clone(), *state));
715                }
716                _ => {}
717            }
718        }
719        actions
720    }
721
722    /// Mark a name as conflicted
723    pub fn mark_conflict(&mut self, name: &str) {
724        if let Some(state) = self.probe_states.get_mut(name) {
725            *state = ProbeState::Conflict;
726        }
727    }
728
729    // ---- Query Handling ----
730
731    /// Check if a question matches any of our registered records
732    pub fn has_answer(&self, question: &MdnsQuestion) -> bool {
733        let qname = &question.name;
734        let fqdn = self.fqdn();
735
736        match question.qtype {
737            MdnsRecordType::A | MdnsRecordType::ANY => {
738                if qname == &fqdn {
739                    return true;
740                }
741            }
742            MdnsRecordType::AAAA => {
743                if qname == &fqdn {
744                    return true;
745                }
746            }
747            MdnsRecordType::PTR => {
748                for svc in &self.services {
749                    if qname == &svc.ptr_name() {
750                        return true;
751                    }
752                }
753            }
754            MdnsRecordType::SRV | MdnsRecordType::TXT => {
755                for svc in &self.services {
756                    if qname == &svc.full_name() {
757                        return true;
758                    }
759                }
760            }
761            _ => {}
762        }
763        false
764    }
765
766    /// Generate answer records for a given question
767    pub fn answer(&self, question: &MdnsQuestion) -> Vec<MdnsRecord> {
768        let mut answers = Vec::new();
769        let qname = &question.name;
770        let fqdn = self.fqdn();
771
772        // A record for our hostname
773        if (question.qtype == MdnsRecordType::A || question.qtype == MdnsRecordType::ANY)
774            && qname == &fqdn
775        {
776            answers.push(MdnsRecord {
777                name: fqdn.clone(),
778                rtype: MdnsRecordType::A,
779                cache_flush: true,
780                ttl: TTL_UNIQUE,
781                rdata: self.host_ipv4.to_vec(),
782            });
783        }
784
785        // AAAA record for our hostname
786        if (question.qtype == MdnsRecordType::AAAA || question.qtype == MdnsRecordType::ANY)
787            && qname == &fqdn
788        {
789            answers.push(MdnsRecord {
790                name: fqdn.clone(),
791                rtype: MdnsRecordType::AAAA,
792                cache_flush: true,
793                ttl: TTL_UNIQUE,
794                rdata: self.host_ipv6.to_vec(),
795            });
796        }
797
798        // PTR records for service enumeration
799        if question.qtype == MdnsRecordType::PTR || question.qtype == MdnsRecordType::ANY {
800            for svc in &self.services {
801                if qname == &svc.ptr_name() {
802                    answers.push(MdnsRecord {
803                        name: svc.ptr_name(),
804                        rtype: MdnsRecordType::PTR,
805                        cache_flush: false,
806                        ttl: TTL_SHARED,
807                        rdata: encode_dns_name(&svc.full_name()),
808                    });
809                }
810            }
811        }
812
813        // SRV records for service instances
814        if question.qtype == MdnsRecordType::SRV || question.qtype == MdnsRecordType::ANY {
815            for svc in &self.services {
816                if qname == &svc.full_name() {
817                    answers.push(MdnsRecord {
818                        name: svc.full_name(),
819                        rtype: MdnsRecordType::SRV,
820                        cache_flush: true,
821                        ttl: TTL_UNIQUE,
822                        rdata: svc.encode_srv_rdata(),
823                    });
824                }
825            }
826        }
827
828        // TXT records for service instances
829        if question.qtype == MdnsRecordType::TXT || question.qtype == MdnsRecordType::ANY {
830            for svc in &self.services {
831                if qname == &svc.full_name() {
832                    answers.push(MdnsRecord {
833                        name: svc.full_name(),
834                        rtype: MdnsRecordType::TXT,
835                        cache_flush: true,
836                        ttl: TTL_UNIQUE,
837                        rdata: svc.txt.encode(),
838                    });
839                }
840            }
841        }
842
843        answers
844    }
845
846    // ---- Known-Answer Suppression ----
847
848    /// Filter out answers that the querier already knows (known-answer
849    /// suppression) Per RFC 6762 section 7.1: if the querier includes a
850    /// known answer with TTL >= 50% of our TTL, we suppress that answer.
851    pub fn suppress_known_answers(
852        &self,
853        answers: &mut Vec<MdnsRecord>,
854        known_answers: &[MdnsRecord],
855    ) {
856        answers.retain(|answer| {
857            !known_answers.iter().any(|ka| {
858                ka.name == answer.name
859                    && ka.rtype == answer.rtype
860                    && ka.rdata == answer.rdata
861                    && ka.ttl >= answer.ttl / 2
862            })
863        });
864    }
865
866    // ---- Conflict Detection ----
867
868    /// Check if an incoming record conflicts with any of our registrations
869    pub fn detect_conflict(&self, record: &MdnsRecord) -> bool {
870        let fqdn = self.fqdn();
871
872        // Check hostname conflict
873        if record.name == fqdn
874            && (record.rtype == MdnsRecordType::A || record.rtype == MdnsRecordType::AAAA)
875            && record.rdata != self.host_ipv4.to_vec()
876            && record.rdata != self.host_ipv6.to_vec()
877        {
878            return true;
879        }
880
881        // Check service name conflicts
882        for svc in &self.services {
883            if record.name == svc.full_name()
884                && record.rtype == MdnsRecordType::SRV
885                && record.rdata != svc.encode_srv_rdata()
886            {
887                return true;
888            }
889        }
890
891        false
892    }
893
894    // ---- Cache Management ----
895
896    /// Insert a record into the cache
897    pub fn cache_insert(&mut self, record: MdnsRecord, current_tick: u64) -> Result<(), MdnsError> {
898        // Goodbye packet (TTL=0): remove from cache
899        if record.ttl == 0 {
900            self.cache_remove(&record.name, record.rtype);
901            return Ok(());
902        }
903
904        let ttl = record.ttl;
905        let entry = CacheEntry {
906            original_ttl: ttl,
907            inserted_tick: current_tick,
908            record,
909        };
910
911        // Check total cache size before inserting
912        let total: usize = self.cache.values().map(|v| v.len()).sum();
913
914        let entries = self.cache.entry(entry.record.name.clone()).or_default();
915
916        // If cache-flush bit is set, remove all records of the same type
917        if entry.record.cache_flush {
918            entries.retain(|e| e.record.rtype != entry.record.rtype);
919        }
920
921        // Replace existing identical record or add new
922        if let Some(existing) = entries
923            .iter_mut()
924            .find(|e| e.record.rtype == entry.record.rtype && e.record.rdata == entry.record.rdata)
925        {
926            *existing = entry;
927        } else {
928            if total >= MAX_CACHE_ENTRIES {
929                return Err(MdnsError::CacheFull);
930            }
931            entries.push(entry);
932        }
933
934        Ok(())
935    }
936
937    /// Remove records from cache by name and type
938    pub fn cache_remove(&mut self, name: &str, rtype: MdnsRecordType) {
939        if let Some(entries) = self.cache.get_mut(name) {
940            entries.retain(|e| e.record.rtype != rtype);
941            if entries.is_empty() {
942                self.cache.remove(name);
943            }
944        }
945    }
946
947    /// Look up cached records by name and type
948    pub fn cache_lookup(
949        &self,
950        name: &str,
951        rtype: MdnsRecordType,
952        current_tick: u64,
953    ) -> Vec<&CacheEntry> {
954        match self.cache.get(name) {
955            Some(entries) => entries
956                .iter()
957                .filter(|e| {
958                    (rtype == MdnsRecordType::ANY || e.record.rtype == rtype)
959                        && !e.is_expired(current_tick, self.ticks_per_sec)
960                })
961                .collect(),
962            None => Vec::new(),
963        }
964    }
965
966    /// Evict expired entries from the cache
967    pub fn cache_evict_expired(&mut self, current_tick: u64) {
968        let tps = self.ticks_per_sec;
969        self.cache.retain(|_name, entries| {
970            entries.retain(|e| !e.is_expired(current_tick, tps));
971            !entries.is_empty()
972        });
973    }
974
975    /// Get the number of cached entries
976    pub fn cache_size(&self) -> usize {
977        self.cache.values().map(|v| v.len()).sum()
978    }
979}
980
981// ============================================================================
982// Name Resolution Helpers
983// ============================================================================
984
/// Report whether `name` lies in the `.local` domain.
///
/// At most one trailing dot is ignored. The name matches when it is exactly
/// `local` or ends in `.local`, compared ASCII case-insensitively.
pub fn is_local_name(name: &str) -> bool {
    const LOCAL: &[u8] = b"local";

    let bytes = name.strip_suffix('.').unwrap_or(name).as_bytes();
    match bytes.len().checked_sub(LOCAL.len()) {
        // Shorter than "local": cannot match.
        None => false,
        // Same length: must be the bare top-level label.
        Some(0) => bytes.eq_ignore_ascii_case(LOCAL),
        // Longer: the final label must be "local", preceded by a dot.
        Some(split) => bytes[split - 1] == b'.' && bytes[split..].eq_ignore_ascii_case(LOCAL),
    }
}
1003
/// Check if a name is a reverse lookup address (`.in-addr.arpa` or
/// `.ip6.arpa`).
///
/// Trailing dots are ignored. The suffix is compared ASCII
/// case-insensitively, because DNS names are case-insensitive, so
/// `1.168.192.IN-ADDR.ARPA` is recognised as well. The suffix must be
/// preceded by at least a dot; a bare `in-addr.arpa` does not match.
pub fn is_reverse_lookup(name: &str) -> bool {
    const SUFFIXES: [&[u8]; 2] = [b".in-addr.arpa", b".ip6.arpa"];

    let bytes = name.trim_end_matches('.').as_bytes();
    SUFFIXES.iter().any(|suffix| {
        bytes.len() >= suffix.len()
            && bytes[bytes.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
    })
}
1009
1010/// Build a reverse lookup name for an IPv4 address
1011#[cfg(feature = "alloc")]
pub fn ipv4_reverse_name(addr: [u8; 4]) -> String {
    let mut out = String::new();

    // Octets appear in reverse order, separated by dots.
    for (idx, &octet) in addr.iter().rev().enumerate() {
        if idx > 0 {
            out.push('.');
        }

        // Decimal digits are emitted by hand (no `format!`), suppressing
        // leading zeros.
        let hundreds = octet / 100;
        let tens = (octet / 10) % 10;
        let ones = octet % 10;
        if hundreds > 0 {
            out.push(char::from(b'0' + hundreds));
        }
        if hundreds > 0 || tens > 0 {
            out.push(char::from(b'0' + tens));
        }
        out.push(char::from(b'0' + ones));
    }

    out.push_str(".in-addr.arpa");
    out
}
1036
1037// ============================================================================
1038// DNS Name Encoding/Decoding
1039// ============================================================================
1040
/// Encode a domain name to DNS wire format: each label is prefixed with its
/// length byte and the name ends with the zero-length root label.
///
/// * Empty labels (trailing dots, a leading dot, or `a..b`) are skipped. A
///   zero length byte in the middle of a name would terminate it early and
///   misalign everything that follows in the message.
/// * Labels longer than `MAX_LABEL_LEN` bytes are truncated. Truncation backs
///   up to a UTF-8 character boundary so the label stays valid UTF-8.
/// * An empty name (or one made only of dots) encodes as the single root byte.
#[cfg(feature = "alloc")]
pub fn encode_dns_name(name: &str) -> Vec<u8> {
    let mut buf = Vec::new();
    for label in name.split('.').filter(|label| !label.is_empty()) {
        // Truncate oversized labels without splitting a multi-byte character.
        let mut len = label.len().min(MAX_LABEL_LEN);
        while !label.is_char_boundary(len) {
            len -= 1;
        }
        // `len` is at most MAX_LABEL_LEN here; as in the original code, the
        // cast relies on that limit fitting in one byte.
        buf.push(len as u8);
        buf.extend_from_slice(&label.as_bytes()[..len]);
    }
    buf.push(0); // Root label
    buf
}
1064
/// Decode a DNS wire-format name starting at `offset` in `buf`.
///
/// Returns the dotted name together with the number of bytes the name
/// occupies at `offset`. When a compression pointer is encountered, that
/// count stops after the first two-byte pointer; labels reached through
/// pointers are not counted.
///
/// # Errors
///
/// * `MdnsError::MessageTooShort` if a length byte, pointer, or label runs
///   past the end of `buf`.
/// * `MdnsError::InvalidLabel` if a label exceeds `MAX_LABEL_LEN`, is not
///   valid UTF-8, or more than `MAX_NAME_LEN` pointers are followed.
/// * `MdnsError::NameTooLong` if the decoded name exceeds `MAX_NAME_LEN`.
#[cfg(feature = "alloc")]
pub fn decode_dns_name(buf: &[u8], offset: usize) -> Result<(String, usize), MdnsError> {
    let mut decoded = String::new();
    let mut cursor = offset;
    // Fixed the first time a pointer is followed; `None` while still reading
    // the name in place.
    let mut wire_len: Option<usize> = None;
    let mut pointer_hops = 0;

    loop {
        let marker = match buf.get(cursor) {
            Some(&byte) => byte as usize,
            None => return Err(MdnsError::MessageTooShort),
        };

        // Root label: the name is complete.
        if marker == 0 {
            let consumed = match wire_len {
                Some(len) => len,
                None => cursor - offset + 1,
            };
            return Ok((decoded, consumed));
        }

        // Compression pointer: the low 14 bits are an absolute offset.
        if marker & 0xC0 == 0xC0 {
            let low = match buf.get(cursor + 1) {
                Some(&byte) => byte as usize,
                None => return Err(MdnsError::MessageTooShort),
            };
            if wire_len.is_none() {
                wire_len = Some(cursor - offset + 2);
            }
            cursor = ((marker & 0x3F) << 8) | low;
            pointer_hops += 1;
            // Bound the number of jumps so pointer loops terminate.
            if pointer_hops > MAX_NAME_LEN {
                return Err(MdnsError::InvalidLabel);
            }
            continue;
        }

        // Ordinary label: length byte followed by that many bytes.
        if marker > MAX_LABEL_LEN {
            return Err(MdnsError::InvalidLabel);
        }
        let start = cursor + 1;
        let end = start + marker;
        if end > buf.len() {
            return Err(MdnsError::MessageTooShort);
        }
        let label = core::str::from_utf8(&buf[start..end]).map_err(|_| MdnsError::InvalidLabel)?;

        if !decoded.is_empty() {
            decoded.push('.');
        }
        decoded.push_str(label);
        cursor = end;

        if decoded.len() > MAX_NAME_LEN {
            return Err(MdnsError::NameTooLong);
        }
    }
}
1129
/// Build an mDNS probe query for `name`: a 12-byte header with ID 0, zero
/// flags and a single question, followed by that question asking for type
/// ANY with class IN and the QU bit set.
#[cfg(feature = "alloc")]
pub fn build_probe_query(name: &str) -> Vec<u8> {
    let mut msg = Vec::with_capacity(MAX_MDNS_MSG_SIZE);

    // Header, written as six big-endian 16-bit fields.
    let header: [u16; 6] = [
        0, // ID = 0 for mDNS
        0, // Flags: standard query
        1, // QDCOUNT = 1
        0, // ANCOUNT = 0
        0, // NSCOUNT = 0
        0, // ARCOUNT = 0
    ];
    for field in header.iter() {
        msg.extend_from_slice(&field.to_be_bytes());
    }

    // Question: name, type=ANY, class=IN|QU
    msg.extend_from_slice(&encode_dns_name(name));
    let qtype = MdnsRecordType::ANY.to_u16();
    msg.extend_from_slice(&qtype.to_be_bytes());
    let qclass = MdnsClass::IN as u16 | QU_BIT;
    msg.extend_from_slice(&qclass.to_be_bytes());

    msg
}
1151
1152// ============================================================================
1153// Trait: eq_ignore_ascii_case for byte slices
1154// ============================================================================
1155
/// ASCII case-insensitive equality against a raw byte slice.
trait EqIgnoreAsciiCase {
    /// Returns `true` when `self` and `other` have the same length and every
    /// byte pair is equal ignoring ASCII case.
    fn eq_ignore_ascii_case(&self, other: &[u8]) -> bool;
}

impl EqIgnoreAsciiCase for [u8] {
    fn eq_ignore_ascii_case(&self, other: &[u8]) -> bool {
        self.len() == other.len()
            && self
                .iter()
                .zip(other)
                .all(|(lhs, rhs)| lhs.eq_ignore_ascii_case(rhs))
    }
}

impl EqIgnoreAsciiCase for str {
    fn eq_ignore_ascii_case(&self, other: &[u8]) -> bool {
        // Compare the string's UTF-8 bytes using the slice comparison.
        <[u8]>::eq_ignore_ascii_case(self.as_bytes(), other)
    }
}
1179
1180// ============================================================================
1181// Tests
1182// ============================================================================
1183
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    // ---- Shared fixtures ----

    /// Build a resource record from its individual fields.
    fn record(
        name: &str,
        rtype: MdnsRecordType,
        cache_flush: bool,
        ttl: u32,
        rdata: Vec<u8>,
    ) -> MdnsRecord {
        MdnsRecord {
            name: String::from(name),
            rtype,
            cache_flush,
            ttl,
            rdata,
        }
    }

    /// Build a service instance with an empty TXT record and with port,
    /// priority and weight zeroed; tests overwrite the fields they exercise.
    fn service(instance_name: &str, service_type: &str, target: &str) -> ServiceInstance {
        ServiceInstance {
            instance_name: String::from(instance_name),
            service_type: ServiceType::parse(service_type).unwrap(),
            port: 0,
            target: String::from(target),
            priority: 0,
            weight: 0,
            txt: TxtRecord::new(),
        }
    }

    /// PTR record for `Server._http._tcp.local` with the given TTL.
    fn http_ptr(ttl: u32) -> MdnsRecord {
        record(
            "_http._tcp.local",
            MdnsRecordType::PTR,
            false,
            ttl,
            encode_dns_name("Server._http._tcp.local"),
        )
    }

    // ---- Service type parsing ----

    #[test]
    fn test_parse_http_tcp_local() {
        let parsed = ServiceType::parse("_http._tcp.local").unwrap();
        assert_eq!(parsed.service, "http");
        assert_eq!(parsed.protocol, "tcp");
        assert_eq!(parsed.domain, "local");
    }

    #[test]
    fn test_parse_ssh_tcp_local() {
        let parsed = ServiceType::parse("_ssh._tcp.local").unwrap();
        assert_eq!(parsed.service, "ssh");
        assert_eq!(parsed.protocol, "tcp");
        assert_eq!(parsed.domain, "local");
    }

    #[test]
    fn test_parse_udp_service() {
        let parsed = ServiceType::parse("_tftp._udp.local").unwrap();
        assert_eq!(parsed.protocol, "udp");
    }

    #[test]
    fn test_parse_invalid_no_underscore() {
        let outcome = ServiceType::parse("http._tcp.local");
        assert_eq!(outcome, Err(MdnsError::InvalidServiceType));
    }

    #[test]
    fn test_parse_invalid_protocol() {
        let outcome = ServiceType::parse("_http._sctp.local");
        assert_eq!(outcome, Err(MdnsError::InvalidServiceType));
    }

    #[test]
    fn test_service_type_roundtrip() {
        let parsed = ServiceType::parse("_http._tcp.local").unwrap();
        assert_eq!(parsed.to_service_string(), "_http._tcp.local");
    }

    // ---- TXT record encoding/decoding ----

    #[test]
    fn test_txt_encode_decode() {
        let mut original = TxtRecord::new();
        original.add("path", "/index.html").unwrap();
        original.add("version", "1").unwrap();

        let roundtripped = TxtRecord::decode(&original.encode()).unwrap();

        assert_eq!(roundtripped.entries.len(), 2);
        assert_eq!(roundtripped.get("path"), Some("/index.html"));
        assert_eq!(roundtripped.get("version"), Some("1"));
    }

    #[test]
    fn test_txt_empty() {
        // Single zero byte per RFC 6763
        assert_eq!(TxtRecord::new().encode(), vec![0u8]);
    }

    #[test]
    fn test_txt_boolean_key() {
        let mut original = TxtRecord::new();
        original.add("paper", "").unwrap();

        let roundtripped = TxtRecord::decode(&original.encode()).unwrap();
        let first = &roundtripped.entries[0];
        assert_eq!(first.key, "paper");
        assert_eq!(first.value, "");
    }

    // ---- SRV record ----

    #[test]
    fn test_srv_record_construction() {
        let mut svc = service("My Web Server", "_http._tcp.local", "myhost.local");
        svc.port = 8080;

        let srv = svc.to_srv();
        assert_eq!(srv.port, 8080);
        assert_eq!(srv.priority, 0);
        assert_eq!(srv.target, "myhost.local");
    }

    #[test]
    fn test_srv_rdata_encoding() {
        let mut svc = service("Test", "_http._tcp.local", "host.local");
        svc.port = 80;
        svc.priority = 10;
        svc.weight = 20;

        let rdata = svc.encode_srv_rdata();
        // Big-endian priority=10, weight=20, port=80 (2 bytes each)
        assert_eq!(rdata[..6], [0u8, 10, 0, 20, 0, 80]);
        // Followed by encoded "host.local": first byte is the "host" length
        assert_eq!(rdata[6], 4);
    }

    // ---- PTR record for service enumeration ----

    #[test]
    fn test_ptr_name_for_service() {
        let mut svc = service("My Printer", "_ipp._tcp.local", "printer.local");
        svc.port = 631;

        assert_eq!(svc.ptr_name(), "_ipp._tcp.local");
        assert_eq!(svc.full_name(), "My Printer._ipp._tcp.local");
    }

    // ---- Conflict detection ----

    #[test]
    fn test_conflict_same_name_different_data() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        resp.set_ipv4([192, 168, 1, 10]);

        // Same name, different IP
        let conflicting = record(
            "myhost.local",
            MdnsRecordType::A,
            true,
            TTL_UNIQUE,
            vec![192, 168, 1, 99],
        );
        assert!(resp.detect_conflict(&conflicting));
    }

    #[test]
    fn test_no_conflict_same_data() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        resp.set_ipv4([192, 168, 1, 10]);

        // Same name, same IP
        let same = record(
            "myhost.local",
            MdnsRecordType::A,
            true,
            TTL_UNIQUE,
            vec![192, 168, 1, 10],
        );
        assert!(!resp.detect_conflict(&same));
    }

    // ---- Known-answer suppression ----

    #[test]
    fn test_known_answer_suppression() {
        let resp = MdnsResponder::new("myhost", 1000);

        let mut answers = vec![http_ptr(TTL_SHARED)];
        // Known answer TTL is >= 50% of our TTL
        let known = vec![http_ptr(TTL_SHARED)];

        resp.suppress_known_answers(&mut answers, &known);
        assert!(answers.is_empty(), "Answer should be suppressed");
    }

    #[test]
    fn test_known_answer_not_suppressed_low_ttl() {
        let resp = MdnsResponder::new("myhost", 1000);

        let mut answers = vec![http_ptr(TTL_SHARED)];
        // Known answer TTL is much less than 50% of TTL_SHARED
        let known = vec![http_ptr(1)];

        resp.suppress_known_answers(&mut answers, &known);
        assert_eq!(
            answers.len(),
            1,
            "Answer should NOT be suppressed with low TTL"
        );
    }

    // ---- TTL expiry ----

    #[test]
    fn test_cache_entry_expiry() {
        let entry = CacheEntry {
            record: record("host.local", MdnsRecordType::A, true, 120, vec![10, 0, 0, 1]),
            inserted_tick: 0,
            original_ttl: 120,
        };

        // 1000 ticks/sec, 119 seconds elapsed: not expired
        let before = 119_000;
        assert!(!entry.is_expired(before, 1000));
        assert_eq!(entry.remaining_ttl(before, 1000), 1);

        // 120 seconds elapsed: expired
        let at_ttl = 120_000;
        assert!(entry.is_expired(at_ttl, 1000));
        assert_eq!(entry.remaining_ttl(at_ttl, 1000), 0);
    }

    // ---- Goodbye packet generation ----

    #[test]
    fn test_goodbye_packets_on_deregister() {
        let mut resp = MdnsResponder::new("myhost", 1000);

        let mut svc = service("WebSrv", "_http._tcp.local", "myhost.local");
        svc.port = 80;
        svc.txt.add("path", "/").unwrap();
        resp.register_service(svc).unwrap();

        let goodbyes = resp.deregister_service("WebSrv");
        assert_eq!(goodbyes.len(), 3); // PTR, SRV, TXT
        assert!(
            goodbyes.iter().all(|g| g.ttl == 0),
            "Goodbye packets must have TTL=0"
        );

        let expected_order = [
            MdnsRecordType::PTR,
            MdnsRecordType::SRV,
            MdnsRecordType::TXT,
        ];
        for (goodbye, expected) in goodbyes.iter().zip(expected_order.iter()) {
            assert_eq!(goodbye.rtype, *expected);
        }
    }

    // ---- Service registration and lookup ----

    #[test]
    fn test_register_and_lookup_service() {
        let mut resp = MdnsResponder::new("myhost", 1000);

        let mut svc = service("My SSH", "_ssh._tcp.local", "myhost.local");
        svc.port = 22;
        resp.register_service(svc).unwrap();

        let browsed = resp.browse_services("_ssh._tcp.local");
        assert_eq!(browsed.len(), 1);
        assert_eq!(browsed[0].port, 22);

        assert!(resp.lookup_service("My SSH._ssh._tcp.local").is_some());
    }

    // ---- .local suffix detection ----

    #[test]
    fn test_is_local_name() {
        for name in ["myhost.local", "myhost.local.", "sub.myhost.local"].iter() {
            assert!(is_local_name(name));
        }
        for name in ["myhost.com", "myhost.localhost"].iter() {
            assert!(!is_local_name(name));
        }
    }

    // ---- Probe message construction ----

    #[test]
    fn test_build_probe_query() {
        let query = build_probe_query("myhost.local");

        // Header: ID=0, QDCOUNT=1, ANCOUNT=0
        assert_eq!(query[..2], [0u8, 0]); // ID
        assert_eq!(query[4..8], [0u8, 1, 0, 0]); // QDCOUNT, ANCOUNT

        // Question section starts at offset 12:
        // "myhost" (6), "local" (5), root label
        assert_eq!(&query[12..26], b"\x06myhost\x05local\x00");
        // Type = ANY (255)
        assert_eq!(query[26..28], [0u8, 255]);
        // Class = IN | QU bit (0x8001)
        assert_eq!(query[28..30], [0x80u8, 0x01]);
    }

    // ---- DNS name encoding ----

    #[test]
    fn test_encode_dns_name() {
        let wire = encode_dns_name("myhost.local");
        // "myhost" label, "local" label, then the root label
        assert_eq!(&wire[..14], b"\x06myhost\x05local\x00");
    }

    #[test]
    fn test_decode_dns_name() {
        let wire = encode_dns_name("test.local");
        let (name, consumed) = decode_dns_name(&wire, 0).unwrap();
        assert_eq!(name, "test.local");
        assert_eq!(consumed, wire.len());
    }

    // ---- Reverse lookup ----

    #[test]
    fn test_is_reverse_lookup() {
        for name in ["1.168.192.in-addr.arpa", "8.b.d.0.1.0.0.2.ip6.arpa"].iter() {
            assert!(is_reverse_lookup(name));
        }
        assert!(!is_reverse_lookup("myhost.local"));
    }

    #[test]
    fn test_ipv4_reverse_name() {
        assert_eq!(
            ipv4_reverse_name([192, 168, 1, 10]),
            "10.1.168.192.in-addr.arpa"
        );
    }

    // ---- Cache operations ----

    #[test]
    fn test_cache_insert_and_lookup() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        let cached = record("other.local", MdnsRecordType::A, false, 120, vec![10, 0, 0, 2]);
        resp.cache_insert(cached, 0).unwrap();

        let hits = resp.cache_lookup("other.local", MdnsRecordType::A, 0);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].record.rdata, vec![10, 0, 0, 2]);
    }

    #[test]
    fn test_cache_goodbye_removes() {
        let mut resp = MdnsResponder::new("myhost", 1000);

        let live = record("gone.local", MdnsRecordType::A, false, 120, vec![10, 0, 0, 3]);
        resp.cache_insert(live, 0).unwrap();
        assert_eq!(resp.cache_size(), 1);

        // Goodbye packet: same record with TTL=0
        let goodbye = record("gone.local", MdnsRecordType::A, false, 0, vec![10, 0, 0, 3]);
        resp.cache_insert(goodbye, 1000).unwrap();
        assert_eq!(resp.cache_size(), 0);
    }

    // ---- Probe state machine ----

    #[test]
    fn test_probe_state_machine() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        resp.start_probe("myhost.local", 0);

        // Probe interval = 250ms * 1000 tps = 250 ticks
        let first = resp.tick_probes(250);
        assert_eq!(first.len(), 1);

        let second = resp.tick_probes(500);
        assert_eq!(second.len(), 1);
        // After 3rd probe -> Announcing
        assert!(
            matches!(second[0].1, ProbeState::Announcing { .. }),
            "Expected Announcing state after 3 probes"
        );
    }

    // ---- Query answering ----

    #[test]
    fn test_answer_a_record() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        resp.set_ipv4([192, 168, 1, 42]);

        let question = MdnsQuestion {
            name: String::from("myhost.local"),
            qtype: MdnsRecordType::A,
            unicast: false,
        };

        assert!(resp.has_answer(&question));

        let answers = resp.answer(&question);
        assert_eq!(answers.len(), 1);

        let only = &answers[0];
        assert_eq!(only.rtype, MdnsRecordType::A);
        assert_eq!(only.rdata, vec![192, 168, 1, 42]);
        assert_eq!(only.ttl, TTL_UNIQUE);
    }
}