1#![allow(dead_code)]
17
18#[cfg(feature = "alloc")]
19use alloc::{collections::BTreeMap, string::String, vec::Vec};
20
/// IPv4 multicast group address used by mDNS: 224.0.0.251.
pub const MDNS_IPV4_ADDR: [u8; 4] = [0xE0, 0x00, 0x00, 0xFB];

/// IPv6 multicast group address used by mDNS: ff02::fb.
pub const MDNS_IPV6_ADDR: [u8; 16] = [
    0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
];

/// UDP port used by mDNS.
pub const MDNS_PORT: u16 = 5_353;

/// TTL, in seconds, for the records this responder publishes with the
/// cache-flush bit set (A, AAAA, SRV, TXT): two minutes.
pub const TTL_UNIQUE: u32 = 2 * 60;

/// TTL, in seconds, for shared records (PTR): 75 minutes.
pub const TTL_SHARED: u32 = 75 * 60;

/// Delay between consecutive probe steps, in milliseconds.
pub const PROBE_INTERVAL_MS: u64 = 250;

/// Probe-step count after which probing switches to announcing.
pub const PROBE_COUNT: u8 = 3;

/// Delay between consecutive announcement steps, in milliseconds.
pub const ANNOUNCE_INTERVAL_MS: u64 = 1_000;

/// Announcement-step count after which a name is considered claimed.
pub const ANNOUNCE_COUNT: u8 = 2;

/// Initial capacity reserved when building an outgoing message.
const MAX_MDNS_MSG_SIZE: usize = 512;

/// Maximum length of one DNS label, in bytes.
const MAX_LABEL_LEN: usize = 63;

/// Maximum length of a decoded DNS name; also the pointer-jump limit when decoding.
const MAX_NAME_LEN: usize = 255;

/// Maximum number of records held in the responder cache.
const MAX_CACHE_ENTRIES: usize = 512;

/// Maximum number of services one responder may register.
const MAX_SERVICES: usize = 64;

/// Maximum number of key/value pairs in one TXT record.
const MAX_TXT_PAIRS: usize = 16;

/// Top bit of a question's class field: "unicast response requested" (QU).
const QU_BIT: u16 = 1 << 15;

/// Top bit of a resource record's class field: the cache-flush flag.
const CACHE_FLUSH_BIT: u16 = 1 << 15;

/// Suffix appended to the bare host name to form its fully-qualified name.
pub const LOCAL_SUFFIX: &str = ".local";
78
/// DNS resource-record types handled by this module.
///
/// Each variant's discriminant is its on-the-wire TYPE code; codes not listed
/// here are represented as [`MdnsRecordType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u16)]
pub enum MdnsRecordType {
    /// IPv4 host address.
    A = 1,
    /// Name pointer (used for service enumeration).
    PTR = 12,
    /// Text strings (DNS-SD key/value metadata).
    TXT = 16,
    /// IPv6 host address.
    AAAA = 28,
    /// Service location (priority, weight, port, target).
    SRV = 33,
    /// Wildcard type in questions; matches every type.
    ANY = 255,
    /// Any TYPE code not listed above.
    Unknown = 0,
}

impl MdnsRecordType {
    /// Variants with a real wire code, searched by [`MdnsRecordType::from_u16`].
    const WIRE_TYPES: [Self; 6] = [Self::A, Self::PTR, Self::TXT, Self::AAAA, Self::SRV, Self::ANY];

    /// Maps a wire TYPE code to its variant, falling back to `Unknown`.
    pub fn from_u16(val: u16) -> Self {
        Self::WIRE_TYPES
            .iter()
            .copied()
            .find(|candidate| candidate.to_u16() == val)
            .unwrap_or(Self::Unknown)
    }

    /// Returns the wire TYPE code (`0` for `Unknown`).
    pub fn to_u16(self) -> u16 {
        self as u16
    }
}
120
121#[derive(Debug, Clone, Copy, PartialEq, Eq)]
123#[repr(u16)]
124pub enum MdnsClass {
125 IN = 1,
127 ANY = 255,
129}
130
131impl MdnsClass {
132 pub fn from_u16(val: u16) -> Self {
133 match val & !QU_BIT {
134 1 => Self::IN,
135 255 => Self::ANY,
136 _ => Self::IN,
137 }
138 }
139}
140
/// Errors produced by the mDNS encoding, parsing and responder routines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MdnsError {
    /// The input buffer ended before a complete name or label could be read.
    MessageTooShort,
    /// A label was malformed: a non-pointer length byte above 63, bytes that are
    /// not valid UTF-8, or a compression-pointer chain exceeding the jump limit.
    InvalidLabel,
    /// A decoded name grew beyond the maximum name length.
    NameTooLong,
    /// An output buffer was too small. Not returned by the functions in this module.
    BufferTooSmall,
    /// The responder already holds the maximum number of registered services.
    TooManyServices,
    /// A name conflict was detected.
    NameConflict,
    /// The record cache is at capacity.
    CacheFull,
    /// A service type string was not of the form `_service._tcp|_udp.domain`.
    InvalidServiceType,
    /// A TXT record has too many pairs, or one pair is too long to encode.
    TxtTooLarge,
    /// A requested item was not found. Not returned by the functions in this module.
    NotFound,
    /// Malformed data, e.g. a truncated or non-UTF-8 TXT record string.
    InvalidFormat,
}
171
/// Decoded fields of an SRV record, in RDATA order (see
/// `ServiceInstance::encode_srv_rdata` for the wire layout).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct SrvRecord {
    /// SRV priority field (first 16-bit word of the RDATA).
    pub priority: u16,
    /// SRV weight field (second 16-bit word).
    pub weight: u16,
    /// Port on which the service listens (third 16-bit word).
    pub port: u16,
    /// Host name the service runs on, as a dotted string.
    pub target: String,
}
189
/// DNS-SD TXT record: an ordered list of key/value pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct TxtRecord {
    /// Pairs in insertion order; this is also the order they are encoded on the wire.
    pub entries: Vec<TxtEntry>,
}
197
/// One key/value pair of a [`TxtRecord`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct TxtEntry {
    /// Attribute name (the text before the first `=`).
    pub key: String,
    /// Attribute value; empty for a boolean attribute, which is encoded as the bare key.
    pub value: String,
}
207
#[cfg(feature = "alloc")]
impl Default for TxtRecord {
    /// Returns a TXT record with no key/value pairs, identical to `TxtRecord::new()`.
    fn default() -> Self {
        TxtRecord {
            entries: Vec::new(),
        }
    }
}
214
#[cfg(feature = "alloc")]
impl TxtRecord {
    /// Creates an empty TXT record.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Wire length of one pair's string: the bare key for a boolean attribute
    /// (empty value), otherwise `key=value`.
    fn encoded_len(key: &str, value: &str) -> usize {
        if value.is_empty() {
            key.len()
        } else {
            key.len() + 1 + value.len()
        }
    }

    /// A DNS-SD key must be non-empty, printable US-ASCII (0x20..=0x7E) and
    /// must not contain `=` (RFC 6763 §6.4); otherwise it cannot round-trip.
    fn is_valid_key(key: &str) -> bool {
        !key.is_empty() && key.bytes().all(|b| (0x20..=0x7E).contains(&b) && b != b'=')
    }

    /// Appends a `key=value` pair; an empty `value` makes a boolean attribute.
    ///
    /// # Errors
    /// - [`MdnsError::TxtTooLarge`] if `MAX_TXT_PAIRS` pairs are already present,
    ///   or the encoded pair would exceed the 255-byte string limit.
    /// - [`MdnsError::InvalidFormat`] if `key` is empty, contains `=`, or
    ///   contains bytes outside printable US-ASCII.
    pub fn add(&mut self, key: &str, value: &str) -> Result<(), MdnsError> {
        if self.entries.len() >= MAX_TXT_PAIRS {
            return Err(MdnsError::TxtTooLarge);
        }
        if !Self::is_valid_key(key) {
            return Err(MdnsError::InvalidFormat);
        }
        // Measure what `encode` will actually emit: a boolean attribute has no '='.
        if Self::encoded_len(key, value) > 255 {
            return Err(MdnsError::TxtTooLarge);
        }
        self.entries.push(TxtEntry {
            key: String::from(key),
            value: String::from(value),
        });
        Ok(())
    }

    /// Encodes the pairs as length-prefixed strings (TXT RDATA).
    ///
    /// Pairs whose string would exceed 255 bytes are skipped; only entries
    /// pushed directly into `entries` (bypassing [`TxtRecord::add`]) can be
    /// that long. An empty record encodes as a single zero byte.
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        for entry in &self.entries {
            let len = Self::encoded_len(&entry.key, &entry.value);
            if len > 255 {
                continue;
            }
            buf.push(len as u8);
            buf.extend_from_slice(entry.key.as_bytes());
            if !entry.value.is_empty() {
                buf.push(b'=');
                buf.extend_from_slice(entry.value.as_bytes());
            }
        }
        if buf.is_empty() {
            // An empty TXT record is one zero-length string, never zero bytes.
            buf.push(0);
        }
        buf
    }

    /// Decodes TXT RDATA into key/value pairs.
    ///
    /// Zero-length strings are skipped. Strings without `=` become boolean
    /// attributes with an empty value. Strings with an empty key (leading `=`)
    /// are silently ignored, as RFC 6763 §6.4 requires.
    ///
    /// # Errors
    /// [`MdnsError::InvalidFormat`] if a length prefix overruns `data` or a
    /// string is not valid UTF-8.
    pub fn decode(data: &[u8]) -> Result<Self, MdnsError> {
        let mut entries = Vec::new();
        let mut rest = data;
        while let Some((&len, tail)) = rest.split_first() {
            let len = usize::from(len);
            if len > tail.len() {
                return Err(MdnsError::InvalidFormat);
            }
            let (chunk, remainder) = tail.split_at(len);
            rest = remainder;
            if chunk.is_empty() {
                continue;
            }
            let s = core::str::from_utf8(chunk).map_err(|_| MdnsError::InvalidFormat)?;
            let (key, value) = match s.find('=') {
                Some(eq_pos) => (&s[..eq_pos], &s[eq_pos + 1..]),
                None => (s, ""),
            };
            if key.is_empty() {
                continue;
            }
            entries.push(TxtEntry {
                key: String::from(key),
                value: String::from(value),
            });
        }
        Ok(Self { entries })
    }

    /// Returns the value of the first pair whose key matches `key`.
    ///
    /// Keys compare ASCII-case-insensitively, and later duplicates are ignored
    /// (both per RFC 6763 §6.4).
    pub fn get(&self, key: &str) -> Option<&str> {
        self.entries
            .iter()
            .find(|entry| entry.key.eq_ignore_ascii_case(key))
            .map(|entry| entry.value.as_str())
    }
}
306
/// A resource record as handled by the responder and its cache.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct MdnsRecord {
    /// Owner name, dotted, without a trailing dot (e.g. `myhost.local`).
    pub name: String,
    /// Record type.
    pub rtype: MdnsRecordType,
    /// Whether the cache-flush bit is set; the cache uses it to replace older
    /// records of the same name and type.
    pub cache_flush: bool,
    /// Time to live in seconds; `0` marks a goodbye that retracts the record.
    pub ttl: u32,
    /// Raw RDATA bytes (e.g. 4 address bytes for A, encoded SRV fields for SRV).
    pub rdata: Vec<u8>,
}
326
/// A question the responder may answer.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct MdnsQuestion {
    /// Queried name, dotted, without a trailing dot.
    pub name: String,
    /// Queried type; `ANY` asks for every type.
    pub qtype: MdnsRecordType,
    /// Whether a unicast reply was requested (presumably mirrors the QU bit;
    /// not read by the responder code in this module).
    pub unicast: bool,
}
338
/// A DNS-SD service type such as `_http._tcp.local`, stored without underscores.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct ServiceType {
    /// Service name without the leading `_` (e.g. `http`).
    pub service: String,
    /// Transport without the leading `_`: `tcp` or `udp`.
    pub protocol: String,
    /// Domain, everything after the protocol label (e.g. `local`).
    pub domain: String,
}
354
#[cfg(feature = "alloc")]
impl ServiceType {
    /// Parses `_service._proto.domain`; trailing dots are ignored.
    ///
    /// The service label must be `_` followed by at least one character, the
    /// protocol must be exactly `_tcp` or `_udp`, and the domain is everything
    /// after the second dot.
    ///
    /// # Errors
    /// [`MdnsError::InvalidServiceType`] if any of those rules is violated or
    /// fewer than three dot-separated parts are present.
    pub fn parse(s: &str) -> Result<Self, MdnsError> {
        let mut pieces = s.trim_end_matches('.').splitn(3, '.');
        let (service_part, proto_part, domain) =
            match (pieces.next(), pieces.next(), pieces.next()) {
                (Some(svc), Some(proto), Some(rest)) => (svc, proto, rest),
                _ => return Err(MdnsError::InvalidServiceType),
            };

        let service = match service_part.strip_prefix('_') {
            Some(bare) if !bare.is_empty() => bare,
            _ => return Err(MdnsError::InvalidServiceType),
        };
        let protocol = match proto_part {
            "_tcp" | "_udp" => &proto_part[1..],
            _ => return Err(MdnsError::InvalidServiceType),
        };

        Ok(Self {
            service: String::from(service),
            protocol: String::from(protocol),
            domain: String::from(domain),
        })
    }

    /// Formats the type back into `_service._proto.domain`.
    pub fn to_service_string(&self) -> String {
        let pieces: [&str; 6] = ["_", &self.service, "._", &self.protocol, ".", &self.domain];
        pieces.concat()
    }
}
401
/// A locally offered DNS-SD service instance.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct ServiceInstance {
    /// Human-readable instance label, prepended to the service type to form the
    /// full name. NOTE(review): dots in it are not escaped, so a `.` here splits
    /// into extra labels when encoded.
    pub instance_name: String,
    /// Service type this instance belongs to.
    pub service_type: ServiceType,
    /// Port published in the SRV record.
    pub port: u16,
    /// Host name published as the SRV target.
    pub target: String,
    /// SRV priority field.
    pub priority: u16,
    /// SRV weight field.
    pub weight: u16,
    /// Metadata published in the TXT record.
    pub txt: TxtRecord,
}
425
#[cfg(feature = "alloc")]
impl ServiceInstance {
    /// Full instance name: `<instance_name>.<_service._proto.domain>`.
    ///
    /// This is the owner name of the SRV and TXT records and the target of the PTR.
    pub fn full_name(&self) -> String {
        let service = self.service_type.to_service_string();
        let pieces: [&str; 3] = [&self.instance_name, ".", &service];
        pieces.concat()
    }

    /// Owner name of this instance's PTR record, i.e. the bare service type.
    pub fn ptr_name(&self) -> String {
        self.service_type.to_service_string()
    }

    /// Copies the SRV-related fields into an [`SrvRecord`].
    pub fn to_srv(&self) -> SrvRecord {
        let ServiceInstance {
            priority,
            weight,
            port,
            target,
            ..
        } = self;
        SrvRecord {
            priority: *priority,
            weight: *weight,
            port: *port,
            target: target.clone(),
        }
    }

    /// Encodes SRV RDATA: priority, weight and port as big-endian `u16`s,
    /// followed by the target as an uncompressed DNS name.
    pub fn encode_srv_rdata(&self) -> Vec<u8> {
        let target = encode_dns_name(&self.target);
        let mut rdata = Vec::with_capacity(6 + target.len());
        for field in &[self.priority, self.weight, self.port] {
            rdata.extend_from_slice(&field.to_be_bytes());
        }
        rdata.extend_from_slice(&target);
        rdata
    }
}
464
/// A cached record together with the bookkeeping needed to age it.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg(feature = "alloc")]
pub struct CacheEntry {
    /// The cached record.
    pub record: MdnsRecord,
    /// Caller tick at which the record was inserted or last refreshed.
    pub inserted_tick: u64,
    /// The record's TTL in seconds at insertion time.
    pub original_ttl: u32,
}
480
#[cfg(feature = "alloc")]
impl CacheEntry {
    /// Whole seconds elapsed since insertion.
    ///
    /// Returns `None` when `ticks_per_sec` is zero: time cannot be measured, so
    /// the entry is treated as never aging.
    fn elapsed_secs(&self, current_tick: u64, ticks_per_sec: u64) -> Option<u64> {
        if ticks_per_sec == 0 {
            return None;
        }
        Some(current_tick.saturating_sub(self.inserted_tick) / ticks_per_sec)
    }

    /// Returns `true` once at least `original_ttl` whole seconds have elapsed.
    ///
    /// A `current_tick` earlier than `inserted_tick` counts as zero elapsed
    /// time. Never expires when `ticks_per_sec` is zero.
    pub fn is_expired(&self, current_tick: u64, ticks_per_sec: u64) -> bool {
        match self.elapsed_secs(current_tick, ticks_per_sec) {
            Some(secs) => secs >= u64::from(self.original_ttl),
            None => false,
        }
    }

    /// Seconds of TTL left, saturating at zero.
    ///
    /// Returns `original_ttl` unchanged when `ticks_per_sec` is zero.
    pub fn remaining_ttl(&self, current_tick: u64, ticks_per_sec: u64) -> u32 {
        match self.elapsed_secs(current_tick, ticks_per_sec) {
            // Clamp before narrowing. A bare `as u32` wraps when more than
            // u32::MAX seconds have elapsed and reports a bogus non-zero TTL.
            Some(secs) => {
                let secs = secs.min(u64::from(u32::MAX)) as u32;
                self.original_ttl.saturating_sub(secs)
            }
            None => self.original_ttl,
        }
    }
}
504
/// Progress of a name through the probe → announce → claimed sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbeState {
    /// Registered; probing has not started.
    Idle,
    /// Probing: `sent` probe steps so far, next step due at `next_tick`.
    Probing { sent: u8, next_tick: u64 },
    /// Announcing: `sent` announcement steps so far, next step due at `next_tick`.
    Announcing { sent: u8, next_tick: u64 },
    /// Probing and announcing finished; the name is ours.
    Claimed,
    /// A conflicting record was reported for this name.
    Conflict,
}

impl ProbeState {
    /// `true` once the name has been claimed.
    pub fn is_claimed(&self) -> bool {
        *self == ProbeState::Claimed
    }

    /// `true` if a conflict was recorded for the name.
    pub fn is_conflict(&self) -> bool {
        *self == ProbeState::Conflict
    }
}
535
#[cfg(feature = "alloc")]
/// mDNS responder state: the host's own name and addresses, registered
/// services, per-name probing progress and a record cache.
///
/// All time inputs are caller-supplied tick counts converted to seconds with
/// `ticks_per_sec`.
pub struct MdnsResponder {
    /// Bare host name; [`MdnsResponder::fqdn`] appends `.local`.
    hostname: String,
    /// Locally registered service instances.
    services: Vec<ServiceInstance>,
    /// Record cache filled by `cache_insert`, keyed by record name.
    cache: BTreeMap<String, Vec<CacheEntry>>,
    /// Probe/announce progress per name, keyed by the name being claimed.
    probe_states: BTreeMap<String, ProbeState>,
    /// Ticks per second of the caller's clock.
    ticks_per_sec: u64,
    /// Address published in A records; all-zero until `set_ipv4` is called.
    host_ipv4: [u8; 4],
    /// Address published in AAAA records; all-zero until `set_ipv6` is called.
    host_ipv6: [u8; 16],
}
558
#[cfg(feature = "alloc")]
impl MdnsResponder {
    /// Creates a responder for the bare host name `hostname` (without `.local`).
    ///
    /// `ticks_per_sec` converts caller tick counts to wall time for probe
    /// scheduling and cache expiry. Both addresses start out all-zero.
    pub fn new(hostname: &str, ticks_per_sec: u64) -> Self {
        Self {
            hostname: String::from(hostname),
            services: Vec::new(),
            cache: BTreeMap::new(),
            probe_states: BTreeMap::new(),
            ticks_per_sec,
            host_ipv4: [0; 4],
            host_ipv6: [0; 16],
        }
    }

    /// Sets the IPv4 address published in A records.
    pub fn set_ipv4(&mut self, addr: [u8; 4]) {
        self.host_ipv4 = addr;
    }

    /// Sets the IPv6 address published in AAAA records.
    pub fn set_ipv6(&mut self, addr: [u8; 16]) {
        self.host_ipv6 = addr;
    }

    /// Returns the host's fully-qualified name, `<hostname>.local`.
    pub fn fqdn(&self) -> String {
        let mut name = self.hostname.clone();
        name.push_str(LOCAL_SUFFIX);
        name
    }

    /// DNS name equality: ASCII letters compare case-insensitively (RFC 6762 §16).
    fn names_match(a: &str, b: &str) -> bool {
        a.eq_ignore_ascii_case(b)
    }

    /// Converts milliseconds to caller ticks, saturating instead of overflowing.
    fn ms_to_ticks(&self, ms: u64) -> u64 {
        ms.saturating_mul(self.ticks_per_sec) / 1000
    }

    /// Registers `svc` and puts its full name in [`ProbeState::Idle`].
    ///
    /// # Errors
    /// - [`MdnsError::TooManyServices`] if `MAX_SERVICES` services are registered.
    /// - [`MdnsError::NameConflict`] if a service with the same full name
    ///   (compared case-insensitively) is already registered. A duplicate would
    ///   otherwise publish two contradictory SRV/TXT record sets.
    pub fn register_service(&mut self, svc: ServiceInstance) -> Result<(), MdnsError> {
        if self.services.len() >= MAX_SERVICES {
            return Err(MdnsError::TooManyServices);
        }
        let full_name = svc.full_name();
        if self
            .services
            .iter()
            .any(|existing| Self::names_match(&existing.full_name(), &full_name))
        {
            return Err(MdnsError::NameConflict);
        }
        self.probe_states.insert(full_name, ProbeState::Idle);
        self.services.push(svc);
        Ok(())
    }

    /// Removes every service whose `instance_name` equals `instance_name`.
    ///
    /// Returns goodbye records (TTL 0) to multicast: for each removed service, a
    /// PTR, an SRV and a TXT record, in that order. Their probe states are
    /// dropped as well.
    pub fn deregister_service(&mut self, instance_name: &str) -> Vec<MdnsRecord> {
        let mut goodbyes = Vec::new();
        let mut removed_names = Vec::new();

        self.services.retain(|svc| {
            if svc.instance_name != instance_name {
                return true;
            }
            let full_name = svc.full_name();
            goodbyes.push(MdnsRecord {
                name: svc.ptr_name(),
                rtype: MdnsRecordType::PTR,
                cache_flush: false,
                ttl: 0,
                rdata: encode_dns_name(&full_name),
            });
            goodbyes.push(MdnsRecord {
                name: full_name.clone(),
                rtype: MdnsRecordType::SRV,
                cache_flush: true,
                ttl: 0,
                rdata: svc.encode_srv_rdata(),
            });
            goodbyes.push(MdnsRecord {
                name: full_name.clone(),
                rtype: MdnsRecordType::TXT,
                cache_flush: true,
                ttl: 0,
                rdata: svc.txt.encode(),
            });
            removed_names.push(full_name);
            false
        });

        for name in &removed_names {
            self.probe_states.remove(name);
        }
        goodbyes
    }

    /// Returns registered services whose service-type string equals `service_type`.
    pub fn browse_services(&self, service_type: &str) -> Vec<&ServiceInstance> {
        self.services
            .iter()
            .filter(|svc| svc.service_type.to_service_string() == service_type)
            .collect()
    }

    /// Finds a registered service by its exact full instance name.
    pub fn lookup_service(&self, full_name: &str) -> Option<&ServiceInstance> {
        self.services
            .iter()
            .find(|svc| svc.full_name() == full_name)
    }

    /// Starts probing `name`, overwriting any previous state for it.
    ///
    /// The state starts at `sent: 1`, i.e. the first probe is counted as already
    /// sent (presumably by the caller). The next step is due one probe interval
    /// after `current_tick`.
    pub fn start_probe(&mut self, name: &str, current_tick: u64) {
        let interval_ticks = self.ms_to_ticks(PROBE_INTERVAL_MS);
        self.probe_states.insert(
            String::from(name),
            ProbeState::Probing {
                sent: 1,
                next_tick: current_tick.saturating_add(interval_ticks),
            },
        );
    }

    /// Advances every probing/announcing name whose step is due at `current_tick`.
    ///
    /// Returns `(name, new_state)` for each name that advanced; the caller is
    /// expected to transmit the corresponding probe or announcement.
    pub fn tick_probes(&mut self, current_tick: u64) -> Vec<(String, ProbeState)> {
        let probe_interval = self.ms_to_ticks(PROBE_INTERVAL_MS);
        let announce_interval = self.ms_to_ticks(ANNOUNCE_INTERVAL_MS);
        let mut actions = Vec::new();

        for (name, state) in self.probe_states.iter_mut() {
            let next = match *state {
                ProbeState::Probing { sent, next_tick } if current_tick >= next_tick => {
                    if sent + 1 >= PROBE_COUNT {
                        ProbeState::Announcing {
                            sent: 0,
                            next_tick: current_tick.saturating_add(announce_interval),
                        }
                    } else {
                        ProbeState::Probing {
                            sent: sent + 1,
                            next_tick: current_tick.saturating_add(probe_interval),
                        }
                    }
                }
                ProbeState::Announcing { sent, next_tick } if current_tick >= next_tick => {
                    if sent + 1 >= ANNOUNCE_COUNT {
                        ProbeState::Claimed
                    } else {
                        ProbeState::Announcing {
                            sent: sent + 1,
                            next_tick: current_tick.saturating_add(announce_interval),
                        }
                    }
                }
                _ => continue,
            };
            *state = next;
            actions.push((name.clone(), next));
        }
        actions
    }

    /// Marks `name` as conflicting; does nothing if `name` has no probe state.
    pub fn mark_conflict(&mut self, name: &str) {
        if let Some(state) = self.probe_states.get_mut(name) {
            *state = ProbeState::Conflict;
        }
    }

    /// Returns `true` exactly when [`MdnsResponder::answer`] would return at
    /// least one record for `question`.
    ///
    /// Names compare case-insensitively. An `ANY` question matches host
    /// records and service records alike.
    pub fn has_answer(&self, question: &MdnsQuestion) -> bool {
        let qname = question.name.as_str();
        let qtype = question.qtype;
        let wants = |rtype: MdnsRecordType| qtype == rtype || qtype == MdnsRecordType::ANY;

        if (wants(MdnsRecordType::A) || wants(MdnsRecordType::AAAA))
            && Self::names_match(qname, &self.fqdn())
        {
            return true;
        }

        let wants_ptr = wants(MdnsRecordType::PTR);
        let wants_instance = wants(MdnsRecordType::SRV) || wants(MdnsRecordType::TXT);
        self.services.iter().any(|svc| {
            (wants_ptr && Self::names_match(qname, &svc.ptr_name()))
                || (wants_instance && Self::names_match(qname, &svc.full_name()))
        })
    }

    /// Builds the records answering `question`.
    ///
    /// Records are emitted in the order A, AAAA, PTR, SRV, TXT. Host records use
    /// [`TTL_UNIQUE`] with cache-flush set, PTR records use [`TTL_SHARED`]
    /// without it, and SRV/TXT use [`TTL_UNIQUE`] with it. Names compare
    /// case-insensitively; answers carry this responder's own spelling of the name.
    pub fn answer(&self, question: &MdnsQuestion) -> Vec<MdnsRecord> {
        let qname = question.name.as_str();
        let qtype = question.qtype;
        let wants = |rtype: MdnsRecordType| qtype == rtype || qtype == MdnsRecordType::ANY;
        let mut answers = Vec::new();

        let fqdn = self.fqdn();
        if Self::names_match(qname, &fqdn) {
            if wants(MdnsRecordType::A) {
                answers.push(MdnsRecord {
                    name: fqdn.clone(),
                    rtype: MdnsRecordType::A,
                    cache_flush: true,
                    ttl: TTL_UNIQUE,
                    rdata: self.host_ipv4.to_vec(),
                });
            }
            if wants(MdnsRecordType::AAAA) {
                answers.push(MdnsRecord {
                    name: fqdn.clone(),
                    rtype: MdnsRecordType::AAAA,
                    cache_flush: true,
                    ttl: TTL_UNIQUE,
                    rdata: self.host_ipv6.to_vec(),
                });
            }
        }

        if wants(MdnsRecordType::PTR) {
            for svc in &self.services {
                let ptr_name = svc.ptr_name();
                if Self::names_match(qname, &ptr_name) {
                    answers.push(MdnsRecord {
                        name: ptr_name,
                        rtype: MdnsRecordType::PTR,
                        cache_flush: false,
                        ttl: TTL_SHARED,
                        rdata: encode_dns_name(&svc.full_name()),
                    });
                }
            }
        }

        if wants(MdnsRecordType::SRV) {
            for svc in &self.services {
                let full_name = svc.full_name();
                if Self::names_match(qname, &full_name) {
                    answers.push(MdnsRecord {
                        name: full_name,
                        rtype: MdnsRecordType::SRV,
                        cache_flush: true,
                        ttl: TTL_UNIQUE,
                        rdata: svc.encode_srv_rdata(),
                    });
                }
            }
        }

        if wants(MdnsRecordType::TXT) {
            for svc in &self.services {
                let full_name = svc.full_name();
                if Self::names_match(qname, &full_name) {
                    answers.push(MdnsRecord {
                        name: full_name,
                        rtype: MdnsRecordType::TXT,
                        cache_flush: true,
                        ttl: TTL_UNIQUE,
                        rdata: svc.txt.encode(),
                    });
                }
            }
        }

        answers
    }

    /// Drops from `answers` every record the querier already listed in
    /// `known_answers`.
    ///
    /// A record is dropped if a known answer has the same name
    /// (case-insensitive), type and rdata, and a TTL of at least half the
    /// answer's TTL (RFC 6762 §7.1).
    pub fn suppress_known_answers(
        &self,
        answers: &mut Vec<MdnsRecord>,
        known_answers: &[MdnsRecord],
    ) {
        answers.retain(|answer| {
            !known_answers.iter().any(|known| {
                known.rtype == answer.rtype
                    && known.rdata == answer.rdata
                    && Self::names_match(&known.name, &answer.name)
                    // Compare 2*known >= answer so that odd TTLs are not
                    // rounded down by an integer `answer.ttl / 2`.
                    && u64::from(known.ttl) * 2 >= u64::from(answer.ttl)
            })
        });
    }

    /// Returns `true` if `record` claims one of our unique names with different data.
    ///
    /// An A record for our host name conflicts if its rdata differs from our
    /// IPv4 address, and an AAAA record if it differs from our IPv6 address.
    /// An SRV record for a registered instance conflicts if its rdata differs
    /// from ours. Names compare case-insensitively.
    pub fn detect_conflict(&self, record: &MdnsRecord) -> bool {
        let host_conflict = Self::names_match(&record.name, &self.fqdn())
            && match record.rtype {
                MdnsRecordType::A => record.rdata[..] != self.host_ipv4[..],
                MdnsRecordType::AAAA => record.rdata[..] != self.host_ipv6[..],
                _ => false,
            };
        if host_conflict {
            return true;
        }

        record.rtype == MdnsRecordType::SRV
            && self.services.iter().any(|svc| {
                Self::names_match(&record.name, &svc.full_name())
                    && record.rdata != svc.encode_srv_rdata()
            })
    }

    /// Inserts or refreshes `record` in the cache at `current_tick`.
    ///
    /// - TTL 0 is a goodbye. Only the cached record with the same name, type and
    ///   rdata is removed, so other records of the same RRset survive.
    /// - With `cache_flush` set, cached records of the same name and type that
    ///   were inserted at least one second ago are flushed (RFC 6762 §10.2).
    ///   Records arriving together in one burst therefore coexist.
    /// - An existing record with identical type and rdata is refreshed in place.
    ///
    /// # Errors
    /// [`MdnsError::CacheFull`] if adding a new record would exceed
    /// `MAX_CACHE_ENTRIES`, even after expired records have been evicted.
    pub fn cache_insert(&mut self, record: MdnsRecord, current_tick: u64) -> Result<(), MdnsError> {
        if record.ttl == 0 {
            if let Some(entries) = self.cache.get_mut(&record.name) {
                entries.retain(|e| e.record.rtype != record.rtype || e.record.rdata != record.rdata);
                if entries.is_empty() {
                    self.cache.remove(&record.name);
                }
            }
            return Ok(());
        }

        if self.cache_size() >= MAX_CACHE_ENTRIES {
            // Reclaim space held by expired records before refusing a new one.
            self.cache_evict_expired(current_tick);
        }
        let mut total = self.cache_size();
        let one_second = self.ticks_per_sec;

        let entries = self.cache.entry(record.name.clone()).or_default();

        if record.cache_flush {
            let before = entries.len();
            entries.retain(|e| {
                e.record.rtype != record.rtype
                    || current_tick.saturating_sub(e.inserted_tick) < one_second
            });
            // Keep the capacity check below accurate after flushing.
            total -= before - entries.len();
        }

        if let Some(existing) = entries
            .iter_mut()
            .find(|e| e.record.rtype == record.rtype && e.record.rdata == record.rdata)
        {
            existing.original_ttl = record.ttl;
            existing.inserted_tick = current_tick;
            existing.record = record;
            return Ok(());
        }

        if total >= MAX_CACHE_ENTRIES {
            // Do not leave behind an empty bucket created by `entry()` above.
            if entries.is_empty() {
                self.cache.remove(&record.name);
            }
            return Err(MdnsError::CacheFull);
        }

        entries.push(CacheEntry {
            original_ttl: record.ttl,
            inserted_tick: current_tick,
            record,
        });
        Ok(())
    }

    /// Removes every cached record of `name` with type `rtype`, and drops the
    /// name's bucket once it is empty.
    pub fn cache_remove(&mut self, name: &str, rtype: MdnsRecordType) {
        if let Some(entries) = self.cache.get_mut(name) {
            entries.retain(|e| e.record.rtype != rtype);
            if entries.is_empty() {
                self.cache.remove(name);
            }
        }
    }

    /// Returns unexpired cached records for `name` of type `rtype`; `ANY` matches every type.
    pub fn cache_lookup(
        &self,
        name: &str,
        rtype: MdnsRecordType,
        current_tick: u64,
    ) -> Vec<&CacheEntry> {
        match self.cache.get(name) {
            Some(entries) => entries
                .iter()
                .filter(|e| {
                    (rtype == MdnsRecordType::ANY || e.record.rtype == rtype)
                        && !e.is_expired(current_tick, self.ticks_per_sec)
                })
                .collect(),
            None => Vec::new(),
        }
    }

    /// Drops expired records and any name buckets left empty.
    pub fn cache_evict_expired(&mut self, current_tick: u64) {
        let tps = self.ticks_per_sec;
        self.cache.retain(|_name, entries| {
            entries.retain(|e| !e.is_expired(current_tick, tps));
            !entries.is_empty()
        });
    }

    /// Total number of cached records, expired ones included.
    pub fn cache_size(&self) -> usize {
        self.cache.values().map(|v| v.len()).sum()
    }
}
980
/// Returns `true` if `name` is `local` or ends in `.local`.
///
/// The comparison ignores ASCII case and at most one trailing dot.
pub fn is_local_name(name: &str) -> bool {
    let bytes = name.strip_suffix('.').unwrap_or(name).as_bytes();
    match bytes.len() {
        0..=4 => false,
        5 => bytes.eq_ignore_ascii_case(b"local"),
        len => bytes[len - 6..].eq_ignore_ascii_case(b".local"),
    }
}
1003
/// Returns `true` if `name` is an IPv4 (`*.in-addr.arpa`) or IPv6
/// (`*.ip6.arpa`) reverse-lookup name.
///
/// Trailing dots are ignored. The suffix compares ASCII-case-insensitively,
/// matching DNS name semantics and the behaviour of `is_local_name`.
pub fn is_reverse_lookup(name: &str) -> bool {
    /// Case-insensitive byte-suffix test.
    fn ends_with_ignore_case(name: &[u8], suffix: &[u8]) -> bool {
        name.len() >= suffix.len()
            && name[name.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
    }

    let bytes = name.trim_end_matches('.').as_bytes();
    ends_with_ignore_case(bytes, b".in-addr.arpa") || ends_with_ignore_case(bytes, b".ip6.arpa")
}
1009
1010#[cfg(feature = "alloc")]
1012pub fn ipv4_reverse_name(addr: [u8; 4]) -> String {
1013 let mut name = String::new();
1014 for i in (0..4).rev() {
1016 if !name.is_empty() {
1017 name.push('.');
1018 }
1019 let mut val = addr[i];
1021 if val >= 100 {
1022 name.push((b'0' + val / 100) as char);
1023 val %= 100;
1024 name.push((b'0' + val / 10) as char);
1025 name.push((b'0' + val % 10) as char);
1026 } else if val >= 10 {
1027 name.push((b'0' + val / 10) as char);
1028 name.push((b'0' + val % 10) as char);
1029 } else {
1030 name.push((b'0' + val) as char);
1031 }
1032 }
1033 name.push_str(".in-addr.arpa");
1034 name
1035}
1036
1037#[cfg(feature = "alloc")]
1043pub fn encode_dns_name(name: &str) -> Vec<u8> {
1044 let mut buf = Vec::new();
1045 let name = name.trim_end_matches('.');
1046 if name.is_empty() {
1047 buf.push(0);
1048 return buf;
1049 }
1050 for label in name.split('.') {
1051 let len = label.len();
1052 if len > MAX_LABEL_LEN {
1053 buf.push(MAX_LABEL_LEN as u8);
1055 buf.extend_from_slice(&label.as_bytes()[..MAX_LABEL_LEN]);
1056 } else {
1057 buf.push(len as u8);
1058 buf.extend_from_slice(label.as_bytes());
1059 }
1060 }
1061 buf.push(0); buf
1063}
1064
#[cfg(feature = "alloc")]
/// Decodes a possibly compressed DNS name that starts at `offset` in `buf`.
///
/// Labels are joined with `.` and no trailing dot is produced. Compression
/// pointers (length byte with both top bits set) are followed, with at most
/// `MAX_NAME_LEN` jumps as a loop guard.
///
/// Returns the decoded name and the number of bytes the name occupies at
/// `offset` itself. That count runs up to and including the terminating zero
/// byte, or up to and including the first compression pointer if one was
/// followed.
///
/// # Errors
/// - `MessageTooShort` if the buffer ends before the name does.
/// - `InvalidLabel` for a non-pointer length byte above 63, a label that is
///   not valid UTF-8, or more than `MAX_NAME_LEN` pointer jumps.
/// - `NameTooLong` if the decoded dotted text exceeds `MAX_NAME_LEN` bytes.
pub fn decode_dns_name(buf: &[u8], offset: usize) -> Result<(String, usize), MdnsError> {
    let mut name = String::new();
    let mut pos = offset;
    // Bytes occupied at `offset`; fixed at the first pointer, since bytes read
    // after a jump live elsewhere in the message.
    let mut consumed = 0;
    let mut followed_pointer = false;
    let mut jumps = 0;

    loop {
        if pos >= buf.len() {
            return Err(MdnsError::MessageTooShort);
        }
        let len = buf[pos] as usize;

        // Zero-length label: end of name.
        if len == 0 {
            if !followed_pointer {
                consumed = pos - offset + 1;
            }
            break;
        }

        // Compression pointer: 14-bit offset in the low bits of two bytes.
        if len & 0xC0 == 0xC0 {
            if pos + 1 >= buf.len() {
                return Err(MdnsError::MessageTooShort);
            }
            if !followed_pointer {
                consumed = pos - offset + 2;
                followed_pointer = true;
            }
            let ptr = ((len & 0x3F) << 8) | (buf[pos + 1] as usize);
            pos = ptr;
            // NOTE(review): pointers are not required to point backwards, so
            // termination on cyclic pointers relies solely on this jump limit.
            jumps += 1;
            if jumps > MAX_NAME_LEN {
                return Err(MdnsError::InvalidLabel);
            }
            continue;
        }

        // Ordinary label: reject lengths above 63 (this also rejects the
        // 0x40/0x80 label types).
        if len > MAX_LABEL_LEN {
            return Err(MdnsError::InvalidLabel);
        }
        pos += 1;
        if pos + len > buf.len() {
            return Err(MdnsError::MessageTooShort);
        }

        if !name.is_empty() {
            name.push('.');
        }
        let label =
            core::str::from_utf8(&buf[pos..pos + len]).map_err(|_| MdnsError::InvalidLabel)?;
        name.push_str(label);
        pos += len;

        // Bounds the dotted text length, not the wire length.
        if name.len() > MAX_NAME_LEN {
            return Err(MdnsError::NameTooLong);
        }
    }

    Ok((name, consumed))
}
1129
/// Builds a probe query for `name`: one question of type ANY, class IN, with
/// the QU (unicast-response) bit set.
///
/// NOTE(review): NSCOUNT is 0, so the proposed records are not carried in an
/// Authority section (RFC 6762 §8.2 uses them for simultaneous-probe
/// tie-breaking). Confirm this is intentional.
#[cfg(feature = "alloc")]
pub fn build_probe_query(name: &str) -> Vec<u8> {
    // Header: ID 0, flags 0, QDCOUNT 1, ANCOUNT / NSCOUNT / ARCOUNT 0 (big-endian).
    const HEADER: [u8; 12] = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0];

    let qtype = MdnsRecordType::ANY.to_u16();
    let qclass = MdnsClass::IN as u16 | QU_BIT;

    let mut packet = Vec::with_capacity(MAX_MDNS_MSG_SIZE);
    packet.extend_from_slice(&HEADER);
    packet.extend_from_slice(&encode_dns_name(name));
    packet.extend_from_slice(&qtype.to_be_bytes());
    packet.extend_from_slice(&qclass.to_be_bytes());
    packet
}
1151
/// ASCII case-insensitive comparison against a byte slice.
///
/// Note: both `[u8]` and `str` also have an *inherent* `eq_ignore_ascii_case`,
/// and inherent methods win over trait methods in method-call syntax. This
/// trait is therefore only reached through a fully-qualified call such as
/// `EqIgnoreAsciiCase::eq_ignore_ascii_case(x, y)`.
trait EqIgnoreAsciiCase {
    /// Returns `true` if `self` and `other` have equal length and match byte-for-byte
    /// once ASCII letters are case-folded.
    fn eq_ignore_ascii_case(&self, other: &[u8]) -> bool;
}

impl EqIgnoreAsciiCase for [u8] {
    fn eq_ignore_ascii_case(&self, other: &[u8]) -> bool {
        self.len() == other.len()
            && self
                .iter()
                .zip(other)
                .all(|(lhs, rhs)| lhs.eq_ignore_ascii_case(rhs))
    }
}

impl EqIgnoreAsciiCase for str {
    fn eq_ignore_ascii_case(&self, other: &[u8]) -> bool {
        <[u8]>::eq_ignore_ascii_case(self.as_bytes(), other)
    }
}
1179
#[cfg(test)]
// Unit tests for service-type parsing, TXT/SRV encoding, responder behaviour
// (answers, conflicts, known-answer suppression, cache, probing) and the DNS
// name helpers.
mod tests {
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    // ---- ServiceType parsing ----

    #[test]
    fn test_parse_http_tcp_local() {
        let st = ServiceType::parse("_http._tcp.local").unwrap();
        assert_eq!(st.service, "http");
        assert_eq!(st.protocol, "tcp");
        assert_eq!(st.domain, "local");
    }

    #[test]
    fn test_parse_ssh_tcp_local() {
        let st = ServiceType::parse("_ssh._tcp.local").unwrap();
        assert_eq!(st.service, "ssh");
        assert_eq!(st.protocol, "tcp");
        assert_eq!(st.domain, "local");
    }

    #[test]
    fn test_parse_udp_service() {
        let st = ServiceType::parse("_tftp._udp.local").unwrap();
        assert_eq!(st.protocol, "udp");
    }

    #[test]
    fn test_parse_invalid_no_underscore() {
        assert_eq!(
            ServiceType::parse("http._tcp.local"),
            Err(MdnsError::InvalidServiceType)
        );
    }

    #[test]
    fn test_parse_invalid_protocol() {
        assert_eq!(
            ServiceType::parse("_http._sctp.local"),
            Err(MdnsError::InvalidServiceType)
        );
    }

    #[test]
    fn test_service_type_roundtrip() {
        let st = ServiceType::parse("_http._tcp.local").unwrap();
        assert_eq!(st.to_service_string(), "_http._tcp.local");
    }

    // ---- TXT records ----

    #[test]
    fn test_txt_encode_decode() {
        let mut txt = TxtRecord::new();
        txt.add("path", "/index.html").unwrap();
        txt.add("version", "1").unwrap();

        let encoded = txt.encode();
        let decoded = TxtRecord::decode(&encoded).unwrap();

        assert_eq!(decoded.entries.len(), 2);
        assert_eq!(decoded.get("path"), Some("/index.html"));
        assert_eq!(decoded.get("version"), Some("1"));
    }

    #[test]
    fn test_txt_empty() {
        let txt = TxtRecord::new();
        let encoded = txt.encode();
        // An empty TXT record is a single zero-length string.
        assert_eq!(encoded, vec![0u8]);
    }

    #[test]
    fn test_txt_boolean_key() {
        let mut txt = TxtRecord::new();
        txt.add("paper", "").unwrap();
        let encoded = txt.encode();
        let decoded = TxtRecord::decode(&encoded).unwrap();
        assert_eq!(decoded.entries[0].key, "paper");
        assert_eq!(decoded.entries[0].value, "");
    }

    // ---- SRV records and service naming ----

    #[test]
    fn test_srv_record_construction() {
        let svc = ServiceInstance {
            instance_name: String::from("My Web Server"),
            service_type: ServiceType::parse("_http._tcp.local").unwrap(),
            port: 8080,
            target: String::from("myhost.local"),
            priority: 0,
            weight: 0,
            txt: TxtRecord::new(),
        };
        let srv = svc.to_srv();
        assert_eq!(srv.port, 8080);
        assert_eq!(srv.priority, 0);
        assert_eq!(srv.target, "myhost.local");
    }

    #[test]
    fn test_srv_rdata_encoding() {
        let svc = ServiceInstance {
            instance_name: String::from("Test"),
            service_type: ServiceType::parse("_http._tcp.local").unwrap(),
            port: 80,
            target: String::from("host.local"),
            priority: 10,
            weight: 20,
            txt: TxtRecord::new(),
        };
        let rdata = svc.encode_srv_rdata();
        // priority = 10, big-endian
        assert_eq!(rdata[0], 0);
        assert_eq!(rdata[1], 10);
        // weight = 20
        assert_eq!(rdata[2], 0);
        assert_eq!(rdata[3], 20);
        // port = 80
        assert_eq!(rdata[4], 0);
        assert_eq!(rdata[5], 80);
        // target begins with the 4-byte label "host"
        assert_eq!(rdata[6], 4);
    }

    #[test]
    fn test_ptr_name_for_service() {
        let svc = ServiceInstance {
            instance_name: String::from("My Printer"),
            service_type: ServiceType::parse("_ipp._tcp.local").unwrap(),
            port: 631,
            target: String::from("printer.local"),
            priority: 0,
            weight: 0,
            txt: TxtRecord::new(),
        };
        assert_eq!(svc.ptr_name(), "_ipp._tcp.local");
        assert_eq!(svc.full_name(), "My Printer._ipp._tcp.local");
    }

    // ---- Conflict detection ----

    #[test]
    fn test_conflict_same_name_different_data() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        resp.set_ipv4([192, 168, 1, 10]);

        let conflicting = MdnsRecord {
            name: String::from("myhost.local"),
            rtype: MdnsRecordType::A,
            cache_flush: true,
            ttl: TTL_UNIQUE,
            // differs from our 192.168.1.10
            rdata: vec![192, 168, 1, 99],
        };
        assert!(resp.detect_conflict(&conflicting));
    }

    #[test]
    fn test_no_conflict_same_data() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        resp.set_ipv4([192, 168, 1, 10]);

        let same = MdnsRecord {
            name: String::from("myhost.local"),
            rtype: MdnsRecordType::A,
            cache_flush: true,
            ttl: TTL_UNIQUE,
            // identical to our own address
            rdata: vec![192, 168, 1, 10],
        };
        assert!(!resp.detect_conflict(&same));
    }

    // ---- Known-answer suppression ----

    #[test]
    fn test_known_answer_suppression() {
        let resp = MdnsResponder::new("myhost", 1000);

        let mut answers = vec![MdnsRecord {
            name: String::from("_http._tcp.local"),
            rtype: MdnsRecordType::PTR,
            cache_flush: false,
            ttl: TTL_SHARED,
            rdata: encode_dns_name("Server._http._tcp.local"),
        }];

        let known = vec![MdnsRecord {
            name: String::from("_http._tcp.local"),
            rtype: MdnsRecordType::PTR,
            cache_flush: false,
            // full TTL, well above half
            ttl: TTL_SHARED,
            rdata: encode_dns_name("Server._http._tcp.local"),
        }];

        resp.suppress_known_answers(&mut answers, &known);
        assert!(answers.is_empty(), "Answer should be suppressed");
    }

    #[test]
    fn test_known_answer_not_suppressed_low_ttl() {
        let resp = MdnsResponder::new("myhost", 1000);

        let mut answers = vec![MdnsRecord {
            name: String::from("_http._tcp.local"),
            rtype: MdnsRecordType::PTR,
            cache_flush: false,
            ttl: TTL_SHARED,
            rdata: encode_dns_name("Server._http._tcp.local"),
        }];

        let known = vec![MdnsRecord {
            name: String::from("_http._tcp.local"),
            rtype: MdnsRecordType::PTR,
            cache_flush: false,
            // far below half of TTL_SHARED
            ttl: 1,
            rdata: encode_dns_name("Server._http._tcp.local"),
        }];

        resp.suppress_known_answers(&mut answers, &known);
        assert_eq!(
            answers.len(),
            1,
            "Answer should NOT be suppressed with low TTL"
        );
    }

    // ---- Cache ----

    #[test]
    fn test_cache_entry_expiry() {
        let entry = CacheEntry {
            record: MdnsRecord {
                name: String::from("host.local"),
                rtype: MdnsRecordType::A,
                cache_flush: true,
                ttl: 120,
                rdata: vec![10, 0, 0, 1],
            },
            inserted_tick: 0,
            original_ttl: 120,
        };

        // 119 s elapsed at 1000 ticks/s: one second left.
        assert!(!entry.is_expired(119_000, 1000));
        assert_eq!(entry.remaining_ttl(119_000, 1000), 1);

        // 120 s elapsed: expired.
        assert!(entry.is_expired(120_000, 1000));
        assert_eq!(entry.remaining_ttl(120_000, 1000), 0);
    }

    // ---- Service registration ----

    #[test]
    fn test_goodbye_packets_on_deregister() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        let mut txt = TxtRecord::new();
        txt.add("path", "/").unwrap();

        let svc = ServiceInstance {
            instance_name: String::from("WebSrv"),
            service_type: ServiceType::parse("_http._tcp.local").unwrap(),
            port: 80,
            target: String::from("myhost.local"),
            priority: 0,
            weight: 0,
            txt,
        };
        resp.register_service(svc).unwrap();

        let goodbyes = resp.deregister_service("WebSrv");
        // one PTR, one SRV, one TXT
        assert_eq!(goodbyes.len(), 3);
        for g in &goodbyes {
            assert_eq!(g.ttl, 0, "Goodbye packets must have TTL=0");
        }
        assert_eq!(goodbyes[0].rtype, MdnsRecordType::PTR);
        assert_eq!(goodbyes[1].rtype, MdnsRecordType::SRV);
        assert_eq!(goodbyes[2].rtype, MdnsRecordType::TXT);
    }

    #[test]
    fn test_register_and_lookup_service() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        let svc = ServiceInstance {
            instance_name: String::from("My SSH"),
            service_type: ServiceType::parse("_ssh._tcp.local").unwrap(),
            port: 22,
            target: String::from("myhost.local"),
            priority: 0,
            weight: 0,
            txt: TxtRecord::new(),
        };
        resp.register_service(svc).unwrap();

        let found = resp.browse_services("_ssh._tcp.local");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].port, 22);

        let by_name = resp.lookup_service("My SSH._ssh._tcp.local");
        assert!(by_name.is_some());
    }

    // ---- Name helpers ----

    #[test]
    fn test_is_local_name() {
        assert!(is_local_name("myhost.local"));
        assert!(is_local_name("myhost.local."));
        assert!(is_local_name("sub.myhost.local"));
        assert!(!is_local_name("myhost.com"));
        assert!(!is_local_name("myhost.localhost"));
    }

    #[test]
    fn test_build_probe_query() {
        let query = build_probe_query("myhost.local");
        // ID = 0
        assert_eq!(query[0], 0);
        assert_eq!(query[1], 0);
        // QDCOUNT = 1
        assert_eq!(query[4], 0);
        assert_eq!(query[5], 1);
        // ANCOUNT = 0
        assert_eq!(query[6], 0);
        assert_eq!(query[7], 0);
        // Question name: 6 "myhost" 5 "local" 0
        assert_eq!(query[12], 6);
        assert_eq!(&query[13..19], b"myhost");
        assert_eq!(query[19], 5);
        assert_eq!(&query[20..25], b"local");
        assert_eq!(query[25], 0);
        // QTYPE = ANY (255)
        assert_eq!(query[26], 0);
        assert_eq!(query[27], 255);
        // QCLASS = IN with the QU bit set
        assert_eq!(query[28], 0x80);
        assert_eq!(query[29], 0x01);
    }

    #[test]
    fn test_encode_dns_name() {
        let encoded = encode_dns_name("myhost.local");
        // 6 "myhost"
        assert_eq!(encoded[0], 6);
        assert_eq!(&encoded[1..7], b"myhost");
        // 5 "local"
        assert_eq!(encoded[7], 5);
        assert_eq!(&encoded[8..13], b"local");
        // root terminator
        assert_eq!(encoded[13], 0);
    }

    #[test]
    fn test_decode_dns_name() {
        let encoded = encode_dns_name("test.local");
        let (name, consumed) = decode_dns_name(&encoded, 0).unwrap();
        assert_eq!(name, "test.local");
        assert_eq!(consumed, encoded.len());
    }

    #[test]
    fn test_is_reverse_lookup() {
        assert!(is_reverse_lookup("1.168.192.in-addr.arpa"));
        assert!(is_reverse_lookup("8.b.d.0.1.0.0.2.ip6.arpa"));
        assert!(!is_reverse_lookup("myhost.local"));
    }

    #[test]
    fn test_ipv4_reverse_name() {
        let name = ipv4_reverse_name([192, 168, 1, 10]);
        assert_eq!(name, "10.1.168.192.in-addr.arpa");
    }

    // ---- Cache insert / goodbye ----

    #[test]
    fn test_cache_insert_and_lookup() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        let record = MdnsRecord {
            name: String::from("other.local"),
            rtype: MdnsRecordType::A,
            cache_flush: false,
            ttl: 120,
            rdata: vec![10, 0, 0, 2],
        };
        resp.cache_insert(record, 0).unwrap();

        let results = resp.cache_lookup("other.local", MdnsRecordType::A, 0);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].record.rdata, vec![10, 0, 0, 2]);
    }

    #[test]
    fn test_cache_goodbye_removes() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        let record = MdnsRecord {
            name: String::from("gone.local"),
            rtype: MdnsRecordType::A,
            cache_flush: false,
            ttl: 120,
            rdata: vec![10, 0, 0, 3],
        };
        resp.cache_insert(record, 0).unwrap();
        assert_eq!(resp.cache_size(), 1);

        // A TTL-0 copy of the same record is a goodbye.
        let goodbye = MdnsRecord {
            name: String::from("gone.local"),
            rtype: MdnsRecordType::A,
            cache_flush: false,
            ttl: 0,
            rdata: vec![10, 0, 0, 3],
        };
        resp.cache_insert(goodbye, 1000).unwrap();
        assert_eq!(resp.cache_size(), 0);
    }

    // ---- Probing state machine ----

    #[test]
    fn test_probe_state_machine() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        resp.start_probe("myhost.local", 0);

        // 250 ms later: next probe step is due.
        let actions = resp.tick_probes(250);
        assert_eq!(actions.len(), 1);

        // 500 ms: probe count reached, switch to announcing.
        let actions = resp.tick_probes(500);
        assert_eq!(actions.len(), 1);
        match actions[0].1 {
            ProbeState::Announcing { .. } => {}
            _ => panic!("Expected Announcing state after 3 probes"),
        }
    }

    // ---- Answering ----

    #[test]
    fn test_answer_a_record() {
        let mut resp = MdnsResponder::new("myhost", 1000);
        resp.set_ipv4([192, 168, 1, 42]);

        let q = MdnsQuestion {
            name: String::from("myhost.local"),
            qtype: MdnsRecordType::A,
            unicast: false,
        };

        assert!(resp.has_answer(&q));
        let answers = resp.answer(&q);
        assert_eq!(answers.len(), 1);
        assert_eq!(answers[0].rtype, MdnsRecordType::A);
        assert_eq!(answers[0].rdata, vec![192, 168, 1, 42]);
        assert_eq!(answers[0].ttl, TTL_UNIQUE);
    }
}