1#![allow(dead_code)]
7
8#[cfg(feature = "alloc")]
9use alloc::{collections::BTreeMap, string::String, vec, vec::Vec};
10
11use spin::Mutex;
12
13use super::Ipv4Address;
14
/// Well-known port DNS servers listen on.
pub const DNS_PORT: u16 = 53;

/// Size of the buffer used to build outgoing queries (the classic 512-byte
/// DNS-over-UDP message limit).
const MAX_DNS_MSG_SIZE: usize = 512;

/// Maximum length of a single label, in bytes.
const MAX_LABEL_LEN: usize = 63;

/// Maximum encoded (wire-format) length of a domain name, in bytes.
const MAX_NAME_LEN: usize = 255;

/// Default capacity of the record cache, in records.
const MAX_CACHE_ENTRIES: usize = 256;

/// Maximum number of nameservers a resolver keeps; extra ones are ignored.
const MAX_NAMESERVERS: usize = 3;

/// The two high bits of a name length byte; both set marks a compression pointer.
const LABEL_POINTER_MASK: u8 = 0xC0;
39
/// Resource record TYPE codes understood by this module
/// (RFC 1035 §3.2.2, plus AAAA from RFC 3596 and SRV from RFC 2782).
///
/// Any code not listed maps to [`DnsRecordType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u16)]
pub enum DnsRecordType {
    /// IPv4 host address.
    A = 1,
    /// Authoritative name server.
    NS = 2,
    /// Canonical name (alias).
    CNAME = 5,
    /// Start of a zone of authority.
    SOA = 6,
    /// Domain name pointer.
    PTR = 12,
    /// Mail exchange.
    MX = 15,
    /// Text strings.
    TXT = 16,
    /// IPv6 host address.
    AAAA = 28,
    /// Service locator.
    SRV = 33,
    /// Stand-in for every TYPE code not listed above; encodes as 0.
    Unknown = 0,
}

impl DnsRecordType {
    /// Every named variant; `from_u16` searches this table for a matching code.
    const KNOWN: [DnsRecordType; 9] = [
        Self::A,
        Self::NS,
        Self::CNAME,
        Self::SOA,
        Self::PTR,
        Self::MX,
        Self::TXT,
        Self::AAAA,
        Self::SRV,
    ];

    /// Maps a wire TYPE code to its variant.
    ///
    /// Unrecognised codes, including 0, yield `Unknown`.
    pub fn from_u16(val: u16) -> Self {
        Self::KNOWN
            .iter()
            .copied()
            .find(|known| known.to_u16() == val)
            .unwrap_or(Self::Unknown)
    }

    /// Returns the wire TYPE code of this variant (`Unknown` gives 0).
    pub fn to_u16(self) -> u16 {
        self as u16
    }
}
90
/// DNS CLASS codes recognised by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum DnsClass {
    /// The Internet class.
    IN = 1,
    /// QCLASS `*` (any class).
    ANY = 255,
}

impl DnsClass {
    /// Maps a wire CLASS code to a variant.
    ///
    /// Only 255 maps to `ANY`. Every other code, including classes this enum
    /// cannot represent such as CH (3) or HS (4), is reported as `IN`.
    pub fn from_u16(val: u16) -> Self {
        if val == Self::ANY as u16 {
            Self::ANY
        } else {
            Self::IN
        }
    }
}
110
/// Errors produced while encoding, decoding, caching or resolving DNS data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsError {
    /// The input ended before a complete structure could be read.
    MessageTooShort,
    /// A label was malformed (e.g. longer than 63 bytes).
    InvalidLabel,
    /// An encoded domain name exceeded 255 bytes.
    NameTooLong,
    /// Too many compression pointers were followed while decoding a name.
    CompressionLoop,
    /// The caller-supplied output buffer is too small.
    BufferTooSmall,
    /// The server replied with a non-zero RCODE.
    ServerError(DnsResponseCode),
    /// No nameserver is configured; also returned by the global `resolve` when
    /// the global resolver is unavailable.
    NoNameservers,
    /// No nameserver produced an answer.
    Timeout,
    /// NOTE(review): not returned by any code in this module; presumably
    /// reserved for "no such record" results.
    NotFound,
    /// A message or record was structurally invalid. `init` also uses it when
    /// installing the global resolver fails.
    InvalidFormat,
    /// NOTE(review): not returned by any code in this module (the cache evicts
    /// instead of failing).
    CacheFull,
}
141
/// RCODE values (RFC 1035 §4.1.1) carried in the low four bits of the header flags.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DnsResponseCode {
    /// No error condition.
    NoError = 0,
    /// The server could not interpret the query.
    FormatError = 1,
    /// The server hit an internal problem.
    ServerFailure = 2,
    /// The queried name does not exist.
    NameError = 3,
    /// The server does not support this kind of query.
    NotImplemented = 4,
    /// The server refused the query.
    Refused = 5,
}

impl DnsResponseCode {
    /// Decodes an RCODE from the low four bits of `val`; the high bits are ignored.
    ///
    /// Codes 6..=15 have no variant here and are reported as `ServerFailure`.
    pub fn from_u8(val: u8) -> Self {
        // Indexed by RCODE: each variant's discriminant equals its position.
        const BY_CODE: [DnsResponseCode; 6] = [
            DnsResponseCode::NoError,
            DnsResponseCode::FormatError,
            DnsResponseCode::ServerFailure,
            DnsResponseCode::NameError,
            DnsResponseCode::NotImplemented,
            DnsResponseCode::Refused,
        ];

        BY_CODE
            .get(usize::from(val & 0x0F))
            .copied()
            .unwrap_or(Self::ServerFailure)
    }
}
167
/// Decoded form of the fixed 12-byte DNS message header (RFC 1035 §4.1.1).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DnsHeader {
    /// Transaction ID used to match responses to queries.
    pub id: u16,
    /// `true` for a response, `false` for a query.
    pub qr: bool,
    /// Kind of query; only the low 4 bits are encoded (0 = standard query).
    pub opcode: u8,
    /// Authoritative-answer flag.
    pub aa: bool,
    /// Truncation flag.
    pub tc: bool,
    /// Recursion-desired flag.
    pub rd: bool,
    /// Recursion-available flag.
    pub ra: bool,
    /// Response code.
    pub rcode: DnsResponseCode,
    /// Number of entries in the question section.
    pub qdcount: u16,
    /// Number of records in the answer section.
    pub ancount: u16,
    /// Number of records in the authority section.
    pub nscount: u16,
    /// Number of records in the additional section.
    pub arcount: u16,
}
200
201impl DnsHeader {
202 pub const SIZE: usize = 12;
203
204 pub fn new_query(id: u16) -> Self {
206 Self {
207 id,
208 qr: false,
209 opcode: 0,
210 aa: false,
211 tc: false,
212 rd: true,
213 ra: false,
214 rcode: DnsResponseCode::NoError,
215 qdcount: 1,
216 ancount: 0,
217 nscount: 0,
218 arcount: 0,
219 }
220 }
221
222 pub fn to_bytes(&self, buf: &mut [u8]) -> Result<usize, DnsError> {
224 if buf.len() < Self::SIZE {
225 return Err(DnsError::BufferTooSmall);
226 }
227
228 buf[0..2].copy_from_slice(&self.id.to_be_bytes());
229
230 let mut flags: u16 = 0;
231 if self.qr {
232 flags |= 1 << 15;
233 }
234 flags |= (self.opcode as u16 & 0x0F) << 11;
235 if self.aa {
236 flags |= 1 << 10;
237 }
238 if self.tc {
239 flags |= 1 << 9;
240 }
241 if self.rd {
242 flags |= 1 << 8;
243 }
244 if self.ra {
245 flags |= 1 << 7;
246 }
247 flags |= self.rcode as u16 & 0x0F;
248
249 buf[2..4].copy_from_slice(&flags.to_be_bytes());
250 buf[4..6].copy_from_slice(&self.qdcount.to_be_bytes());
251 buf[6..8].copy_from_slice(&self.ancount.to_be_bytes());
252 buf[8..10].copy_from_slice(&self.nscount.to_be_bytes());
253 buf[10..12].copy_from_slice(&self.arcount.to_be_bytes());
254
255 Ok(Self::SIZE)
256 }
257
258 pub fn from_bytes(buf: &[u8]) -> Result<Self, DnsError> {
260 if buf.len() < Self::SIZE {
261 return Err(DnsError::MessageTooShort);
262 }
263
264 let id = u16::from_be_bytes([buf[0], buf[1]]);
265 let flags = u16::from_be_bytes([buf[2], buf[3]]);
266
267 Ok(Self {
268 id,
269 qr: (flags >> 15) & 1 == 1,
270 opcode: ((flags >> 11) & 0x0F) as u8,
271 aa: (flags >> 10) & 1 == 1,
272 tc: (flags >> 9) & 1 == 1,
273 rd: (flags >> 8) & 1 == 1,
274 ra: (flags >> 7) & 1 == 1,
275 rcode: DnsResponseCode::from_u8((flags & 0x0F) as u8),
276 qdcount: u16::from_be_bytes([buf[4], buf[5]]),
277 ancount: u16::from_be_bytes([buf[6], buf[7]]),
278 nscount: u16::from_be_bytes([buf[8], buf[9]]),
279 arcount: u16::from_be_bytes([buf[10], buf[11]]),
280 })
281 }
282}
283
/// One entry of a message's question section.
#[derive(Debug, Clone)]
#[cfg(feature = "alloc")]
pub struct DnsQuestion {
    /// Dotted domain name being queried.
    pub qname: String,
    /// Requested record type.
    pub qtype: DnsRecordType,
    /// Requested class.
    pub qclass: DnsClass,
}

/// Decoded RDATA of a resource record.
#[derive(Debug, Clone)]
#[cfg(feature = "alloc")]
pub enum DnsRecordData {
    /// IPv4 address (A record).
    A(Ipv4Address),
    /// Raw 16-byte IPv6 address (AAAA record).
    AAAA([u8; 16]),
    /// Canonical name the owner is an alias of.
    CNAME(String),
    /// Mail exchanger with its preference value.
    MX { preference: u16, exchange: String },
    /// All character-strings of the TXT RDATA, concatenated without separators.
    TXT(String),
    /// Domain name pointed to (PTR record).
    PTR(String),
    /// Service location record.
    SRV {
        priority: u16,
        weight: u16,
        port: u16,
        target: String,
    },
    /// Name server host name.
    NS(String),
    /// Undecoded RDATA bytes for types without a dedicated variant (e.g. SOA).
    Raw(Vec<u8>),
}

/// A parsed resource record.
#[derive(Debug, Clone)]
#[cfg(feature = "alloc")]
pub struct DnsRecord {
    /// Owner name, dotted.
    pub name: String,
    /// Record type.
    pub rtype: DnsRecordType,
    /// Record class.
    pub rclass: DnsClass,
    /// Time-to-live exactly as received.
    pub ttl: u32,
    /// Decoded RDATA.
    pub data: DnsRecordData,
}
340
341pub fn encode_name(name: &str, buf: &mut [u8]) -> Result<usize, DnsError> {
350 let mut pos = 0;
351
352 if name.is_empty() {
353 if buf.is_empty() {
354 return Err(DnsError::BufferTooSmall);
355 }
356 buf[0] = 0;
357 return Ok(1);
358 }
359
360 for label in name.split('.') {
361 let len = label.len();
362 if len == 0 {
363 continue;
364 }
365 if len > MAX_LABEL_LEN {
366 return Err(DnsError::InvalidLabel);
367 }
368 if pos + 1 + len >= buf.len() {
369 return Err(DnsError::BufferTooSmall);
370 }
371 buf[pos] = len as u8;
372 pos += 1;
373 buf[pos..pos + len].copy_from_slice(label.as_bytes());
374 pos += len;
375 }
376
377 if pos >= MAX_NAME_LEN {
378 return Err(DnsError::NameTooLong);
379 }
380
381 if pos >= buf.len() {
382 return Err(DnsError::BufferTooSmall);
383 }
384 buf[pos] = 0; pos += 1;
386
387 Ok(pos)
388}
389
#[cfg(feature = "alloc")]
/// Decodes a possibly compressed domain name starting at `offset` in `msg`.
///
/// Labels are joined with `.`. Each label byte becomes one `char`, so bytes
/// `0x80..=0xFF` map to U+0080..U+00FF. Compression pointers (RFC 1035 §4.1.4)
/// are followed, up to 16 jumps. The root name decodes as `""`.
///
/// Returns the name and the number of bytes it occupies at `offset`. That is
/// up to and including the first pointer, or up to and including the
/// terminating zero byte when no pointer was followed.
///
/// # Errors
/// * `MessageTooShort`: a label or pointer runs past the end of `msg`.
/// * `InvalidLabel`: a length byte uses the reserved `0x40`/`0x80` label types.
/// * `NameTooLong`: the name would exceed 255 bytes in wire form.
/// * `CompressionLoop`: more than 16 pointers were followed.
pub fn decode_name(msg: &[u8], offset: usize) -> Result<(String, usize), DnsError> {
    const MAX_JUMPS: usize = 16;

    let mut name = String::new();
    let mut pos = offset;
    // Bytes used at `offset`; fixed as soon as the first pointer is followed.
    let mut consumed: Option<usize> = None;
    let mut jumps = 0;
    // Wire length so far, starting with the 1-byte root terminator.
    let mut wire_len = 1;

    loop {
        let len_byte = *msg.get(pos).ok_or(DnsError::MessageTooShort)?;

        match len_byte & LABEL_POINTER_MASK {
            // Ordinary label (0..=63 bytes) or the terminating root label.
            0x00 => {
                if len_byte == 0 {
                    let consumed = consumed.unwrap_or(pos - offset + 1);
                    return Ok((name, consumed));
                }

                let label_len = usize::from(len_byte);
                let start = pos + 1;
                let end = start + label_len;
                let label = msg.get(start..end).ok_or(DnsError::MessageTooShort)?;

                wire_len += 1 + label_len;
                if wire_len > MAX_NAME_LEN {
                    return Err(DnsError::NameTooLong);
                }

                if !name.is_empty() {
                    name.push('.');
                }
                name.extend(label.iter().map(|&b| b as char));
                pos = end;
            }
            // Compression pointer: 14-bit offset from the start of the message.
            LABEL_POINTER_MASK => {
                let low = *msg.get(pos + 1).ok_or(DnsError::MessageTooShort)?;
                if consumed.is_none() {
                    consumed = Some(pos - offset + 2);
                }
                jumps += 1;
                if jumps > MAX_JUMPS {
                    return Err(DnsError::CompressionLoop);
                }
                pos = (usize::from(len_byte & !LABEL_POINTER_MASK) << 8) | usize::from(low);
            }
            // 0x40 / 0x80 prefixes are reserved (or obsolete extended) label
            // types; reading them as 64..=191-byte labels would be wrong.
            _ => return Err(DnsError::InvalidLabel),
        }
    }
}
459
460pub fn build_query(
468 buf: &mut [u8],
469 id: u16,
470 name: &str,
471 rtype: DnsRecordType,
472) -> Result<usize, DnsError> {
473 let header = DnsHeader::new_query(id);
474 let mut pos = header.to_bytes(buf)?;
475
476 pos += encode_name(name, &mut buf[pos..])?;
478
479 if pos + 4 > buf.len() {
481 return Err(DnsError::BufferTooSmall);
482 }
483 buf[pos..pos + 2].copy_from_slice(&rtype.to_u16().to_be_bytes());
484 pos += 2;
485
486 buf[pos..pos + 2].copy_from_slice(&DnsClass::IN.to_u16().to_be_bytes());
488 pos += 2;
489
490 Ok(pos)
491}
492
493impl DnsClass {
494 pub fn to_u16(self) -> u16 {
495 self as u16
496 }
497}
498
#[cfg(feature = "alloc")]
/// Parses a DNS response into its header and the records of all three record
/// sections (answer, authority, additional), in message order.
///
/// The question section is skipped. If the message ends exactly on a record
/// boundary before every counted record has been read, parsing stops and the
/// records read so far are returned.
///
/// # Errors
/// * `InvalidFormat`: the message is not a response (QR clear), or an RDATA is malformed.
/// * `ServerError(rcode)`: the response carries a non-zero RCODE.
/// * `MessageTooShort`: a section runs past the end of `msg`.
/// * Any error from `decode_name`.
pub fn parse_response(msg: &[u8]) -> Result<(DnsHeader, Vec<DnsRecord>), DnsError> {
    // Smallest possible record: a 1-byte root name plus 10 fixed bytes.
    const MIN_RR_LEN: usize = 1 + 10;

    let header = DnsHeader::from_bytes(msg)?;

    if !header.qr {
        return Err(DnsError::InvalidFormat);
    }
    if header.rcode != DnsResponseCode::NoError {
        return Err(DnsError::ServerError(header.rcode));
    }

    let mut pos = DnsHeader::SIZE;

    // Skip each question: name, then QTYPE (2) and QCLASS (2).
    for _ in 0..header.qdcount {
        let (_qname, consumed) = decode_name(msg, pos)?;
        pos += consumed + 4;
        if pos > msg.len() {
            return Err(DnsError::MessageTooShort);
        }
    }

    let total_rr = usize::from(header.ancount)
        + usize::from(header.nscount)
        + usize::from(header.arcount);
    // The counts come off the wire. Never reserve room for more records than
    // the remaining bytes could possibly hold.
    let max_fit = msg.len().saturating_sub(pos) / MIN_RR_LEN;
    let mut records = Vec::with_capacity(total_rr.min(max_fit));

    for _ in 0..total_rr {
        if pos >= msg.len() {
            break;
        }
        let (name, consumed) = decode_name(msg, pos)?;
        pos += consumed;

        // TYPE (2) + CLASS (2) + TTL (4) + RDLENGTH (2).
        if pos + 10 > msg.len() {
            return Err(DnsError::MessageTooShort);
        }
        let rtype = DnsRecordType::from_u16(u16::from_be_bytes([msg[pos], msg[pos + 1]]));
        let rclass = DnsClass::from_u16(u16::from_be_bytes([msg[pos + 2], msg[pos + 3]]));
        let ttl = u32::from_be_bytes([msg[pos + 4], msg[pos + 5], msg[pos + 6], msg[pos + 7]]);
        let rdlength = usize::from(u16::from_be_bytes([msg[pos + 8], msg[pos + 9]]));
        pos += 10;

        if pos + rdlength > msg.len() {
            return Err(DnsError::MessageTooShort);
        }

        let data = parse_rdata(msg, pos, rdlength, rtype)?;
        pos += rdlength;

        records.push(DnsRecord {
            name,
            rtype,
            rclass,
            ttl,
            data,
        });
    }

    Ok((header, records))
}
562
#[cfg(feature = "alloc")]
/// Decodes the RDATA of one resource record of type `rtype`.
///
/// The RDATA occupies `msg[offset..offset + rdlength]`. The whole message is
/// passed because names inside RDATA may use compression pointers into
/// earlier parts of `msg`. Types without a dedicated decoder (SOA, `Unknown`)
/// are returned as `DnsRecordData::Raw`.
///
/// # Errors
/// * `MessageTooShort`: the RDATA range lies outside `msg`.
/// * `InvalidFormat`: the RDATA has the wrong size for its type, a TXT
///   character-string overruns the RDATA, or an embedded name extends past
///   the end of the RDATA.
/// * Any error from `decode_name`.
fn parse_rdata(
    msg: &[u8],
    offset: usize,
    rdlength: usize,
    rtype: DnsRecordType,
) -> Result<DnsRecordData, DnsError> {
    // Reads a big-endian u16 at `at`; each caller checks the length first.
    fn be_u16(bytes: &[u8], at: usize) -> u16 {
        u16::from_be_bytes([bytes[at], bytes[at + 1]])
    }

    let rdata = msg
        .get(offset..offset + rdlength)
        .ok_or(DnsError::MessageTooShort)?;

    // Decodes a name starting `skip` bytes into the RDATA. Its in-place bytes
    // (up to the terminator or first pointer) must stay inside the RDATA.
    let name_at = |skip: usize| -> Result<String, DnsError> {
        let (name, used) = decode_name(msg, offset + skip)?;
        if skip + used > rdlength {
            return Err(DnsError::InvalidFormat);
        }
        Ok(name)
    };

    match rtype {
        DnsRecordType::A => match rdata {
            &[a, b, c, d] => Ok(DnsRecordData::A(Ipv4Address::new(a, b, c, d))),
            _ => Err(DnsError::InvalidFormat),
        },
        DnsRecordType::AAAA => {
            if rdata.len() != 16 {
                return Err(DnsError::InvalidFormat);
            }
            let mut addr = [0u8; 16];
            addr.copy_from_slice(rdata);
            Ok(DnsRecordData::AAAA(addr))
        }
        DnsRecordType::CNAME => name_at(0).map(DnsRecordData::CNAME),
        DnsRecordType::PTR => name_at(0).map(DnsRecordData::PTR),
        DnsRecordType::NS => name_at(0).map(DnsRecordData::NS),
        DnsRecordType::MX => {
            // 2-byte preference plus at least a 1-byte name.
            if rdlength < 3 {
                return Err(DnsError::InvalidFormat);
            }
            Ok(DnsRecordData::MX {
                preference: be_u16(rdata, 0),
                exchange: name_at(2)?,
            })
        }
        DnsRecordType::TXT => {
            // A sequence of <length byte><bytes> character-strings, concatenated.
            let mut text = String::new();
            let mut rest = rdata;
            while let Some((&len, tail)) = rest.split_first() {
                let len = usize::from(len);
                if len > tail.len() {
                    return Err(DnsError::InvalidFormat);
                }
                let (chunk, remaining) = tail.split_at(len);
                text.extend(chunk.iter().map(|&b| b as char));
                rest = remaining;
            }
            Ok(DnsRecordData::TXT(text))
        }
        DnsRecordType::SRV => {
            // priority, weight, port (2 bytes each) plus at least a 1-byte name.
            if rdlength < 7 {
                return Err(DnsError::InvalidFormat);
            }
            Ok(DnsRecordData::SRV {
                priority: be_u16(rdata, 0),
                weight: be_u16(rdata, 2),
                port: be_u16(rdata, 4),
                target: name_at(6)?,
            })
        }
        _ => Ok(DnsRecordData::Raw(rdata.to_vec())),
    }
}
651
#[cfg(feature = "alloc")]
/// A cached record plus the logical timestamps used for expiry and LRU eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached record.
    record: DnsRecord,
    /// Clock value at or after which the entry is stale.
    expires_at: u64,
    /// Clock value of the latest insert or lookup hit.
    last_used: u64,
}

#[cfg(feature = "alloc")]
/// Bounded record cache keyed by (name, record type), with TTL expiry and LRU eviction.
pub struct DnsCache {
    /// Entries grouped by (name, TYPE code).
    entries: BTreeMap<(String, u16), Vec<CacheEntry>>,
    /// Total number of entries across all keys.
    count: usize,
    /// Capacity; inserting at capacity evicts the least recently used entry.
    max_entries: usize,
    /// Logical clock advanced once per lookup/insert. Expiry times are
    /// measured in these ticks, not in wall-clock seconds.
    clock: u64,
}

#[cfg(feature = "alloc")]
impl Default for DnsCache {
    /// An empty cache with capacity `MAX_CACHE_ENTRIES` and the clock at 0.
    fn default() -> Self {
        let entries = BTreeMap::new();
        Self {
            max_entries: MAX_CACHE_ENTRIES,
            clock: 0,
            count: 0,
            entries,
        }
    }
}
689
#[cfg(feature = "alloc")]
impl DnsCache {
    /// Creates an empty cache with the default capacity (`MAX_CACHE_ENTRIES`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the map key for a name/type pair.
    ///
    /// Domain names compare case-insensitively (RFC 4343), so the name is
    /// ASCII-lowercased. Otherwise `Example.COM` and `example.com` would be
    /// cached and looked up as different names.
    fn cache_key(name: &str, rtype: u16) -> (String, u16) {
        (name.to_ascii_lowercase(), rtype)
    }

    /// Returns clones of all unexpired records cached for `name` / `rtype`.
    ///
    /// Advances the logical clock by one and drops this key's expired entries.
    /// Surviving entries are marked as most recently used. Returns `None` when
    /// nothing unexpired is cached.
    pub fn lookup(&mut self, name: &str, rtype: DnsRecordType) -> Option<Vec<DnsRecord>> {
        self.clock += 1;
        let now = self.clock;

        let key = Self::cache_key(name, rtype.to_u16());
        let bucket = self.entries.get_mut(&key)?;

        let before = bucket.len();
        bucket.retain(|entry| entry.expires_at > now);
        self.count -= before - bucket.len();

        if bucket.is_empty() {
            self.entries.remove(&key);
            return None;
        }

        for entry in bucket.iter_mut() {
            entry.last_used = now;
        }
        Some(bucket.iter().map(|entry| entry.record.clone()).collect())
    }

    /// Caches `record` under `name` for `ttl` ticks of the logical clock.
    ///
    /// Records with a TTL of 0 are not cached (RFC 1035 §3.2.1: such records
    /// are only valid for the transaction in progress). Nothing is stored when
    /// the capacity is 0. Otherwise expired entries are purged first, then
    /// least-recently-used entries are evicted until there is room.
    ///
    /// NOTE(review): the clock advances once per `lookup` / `insert`, not once
    /// per second, so `ttl` is effectively counted in cache operations.
    pub fn insert(&mut self, name: &str, record: DnsRecord, ttl: u32) {
        self.clock += 1;
        let now = self.clock;

        // A zero capacity used to spin forever in the eviction loop.
        if ttl == 0 || self.max_entries == 0 {
            return;
        }

        self.evict_expired(now);

        // `evict_lru` reports whether it removed anything, so this loop always
        // terminates even if `count` and the map were ever to disagree.
        while self.count >= self.max_entries && self.evict_lru() {}

        let key = Self::cache_key(name, record.rtype.to_u16());
        self.entries.entry(key).or_default().push(CacheEntry {
            record,
            expires_at: now + u64::from(ttl),
            last_used: now,
        });
        self.count += 1;
    }

    /// Drops every entry whose expiry time is at or before `now`.
    pub fn evict_expired(&mut self, now: u64) {
        let mut removed = 0;
        self.entries.retain(|_, bucket| {
            let before = bucket.len();
            bucket.retain(|entry| entry.expires_at > now);
            removed += before - bucket.len();
            !bucket.is_empty()
        });
        self.count -= removed;
    }

    /// Removes the single least-recently-used entry.
    ///
    /// Ties go to the first entry in key order, then bucket order. Returns
    /// `false` if there was nothing to evict.
    fn evict_lru(&mut self) -> bool {
        let victim = self
            .entries
            .iter()
            .flat_map(|(key, bucket)| {
                bucket
                    .iter()
                    .enumerate()
                    .map(move |(idx, entry)| (entry.last_used, key, idx))
            })
            .min_by_key(|&(last_used, _, _)| last_used)
            .map(|(_, key, idx)| (key.clone(), idx));

        let (key, idx) = match victim {
            Some(found) => found,
            None => return false,
        };

        if let Some(bucket) = self.entries.get_mut(&key) {
            bucket.remove(idx);
            self.count -= 1;
            if bucket.is_empty() {
                self.entries.remove(&key);
            }
        }
        true
    }

    /// Number of cached entries, including expired ones not yet purged.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Whether the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Removes every entry. The logical clock is left unchanged.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.count = 0;
    }
}
809
#[cfg(feature = "alloc")]
/// A static name-to-address mapping, from a hosts file or added by hand.
#[derive(Debug, Clone)]
pub struct HostEntry {
    /// Host name as supplied.
    pub name: String,
    /// Address the name maps to.
    pub addr: Ipv4Address,
}

#[cfg(feature = "alloc")]
/// Resolver combining static hosts entries, a record cache and a nameserver list.
pub struct DnsResolver {
    /// Configured nameservers, at most `MAX_NAMESERVERS`.
    nameservers: Vec<Ipv4Address>,
    /// Cache of previously resolved records.
    cache: DnsCache,
    /// Static hosts entries, consulted first for A queries.
    hosts: Vec<HostEntry>,
    /// Transaction ID for the next query; wraps on overflow.
    next_id: u16,
}

#[cfg(feature = "alloc")]
impl Default for DnsResolver {
    /// Starts with no nameservers, an empty cache, query ID 1, and a single
    /// `localhost` -> `Ipv4Address::LOCALHOST` hosts entry.
    fn default() -> Self {
        let localhost = HostEntry {
            name: String::from("localhost"),
            addr: Ipv4Address::LOCALHOST,
        };

        Self {
            nameservers: Vec::new(),
            cache: DnsCache::new(),
            hosts: vec![localhost],
            next_id: 1,
        }
    }
}
858
#[cfg(feature = "alloc")]
impl DnsResolver {
    /// Creates a resolver with no nameservers, an empty cache, and a
    /// `localhost` hosts entry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a nameserver. Ignored once `MAX_NAMESERVERS` are configured.
    pub fn add_nameserver(&mut self, addr: Ipv4Address) {
        if self.nameservers.len() < MAX_NAMESERVERS {
            self.nameservers.push(addr);
        }
    }

    /// Appends a static hosts entry. Earlier entries for a name win on lookup.
    pub fn add_host(&mut self, name: &str, addr: Ipv4Address) {
        self.hosts.push(HostEntry {
            name: String::from(name),
            addr,
        });
    }

    /// Reads `nameserver <ipv4>` lines from resolv.conf-style text.
    ///
    /// `#` and `;` start a comment anywhere on a line. The keyword must be a
    /// whole word followed by whitespace, and only the next token is parsed as
    /// the address. Other lines and unparsable addresses are ignored.
    pub fn parse_resolv_conf(&mut self, content: &str) {
        for line in content.lines() {
            let line = line
                .split(|c| c == '#' || c == ';')
                .next()
                .unwrap_or("");
            let mut fields = line.split_whitespace();
            if fields.next() != Some("nameserver") {
                continue;
            }
            if let Some(addr) = fields.next().and_then(parse_ipv4) {
                self.add_nameserver(addr);
            }
        }
    }

    /// Reads `<ipv4> <name> [<alias>...]` lines from hosts-file text.
    ///
    /// Everything from the first `#` on a line is a comment. Lines whose first
    /// field is not a valid IPv4 address (e.g. IPv6 entries) are skipped.
    pub fn parse_hosts(&mut self, content: &str) {
        for line in content.lines() {
            let line = line.split('#').next().unwrap_or("");
            let mut fields = line.split_whitespace();
            if let Some(addr) = fields.next().and_then(parse_ipv4) {
                for hostname in fields {
                    self.add_host(hostname, addr);
                }
            }
        }
    }

    /// Returns the address of the first hosts entry whose name matches `name`
    /// (ASCII case-insensitive, since host names are case-insensitive).
    pub fn lookup_hosts(&self, name: &str) -> Option<Ipv4Address> {
        self.hosts
            .iter()
            .find(|entry| entry.name.eq_ignore_ascii_case(name))
            .map(|entry| entry.addr)
    }

    /// Resolves `name` / `rtype`.
    ///
    /// A queries are first answered from the hosts table (TTL 0). Next the
    /// cache is consulted. Otherwise a query is built and the nameservers are
    /// tried.
    ///
    /// # Errors
    /// Any error from `build_query` (invalid name), `NoNameservers` if none is
    /// configured, otherwise `Timeout`.
    pub fn resolve(
        &mut self,
        name: &str,
        rtype: DnsRecordType,
    ) -> Result<Vec<DnsRecord>, DnsError> {
        if rtype == DnsRecordType::A {
            if let Some(addr) = self.lookup_hosts(name) {
                return Ok(vec![DnsRecord {
                    name: String::from(name),
                    rtype: DnsRecordType::A,
                    rclass: DnsClass::IN,
                    ttl: 0,
                    data: DnsRecordData::A(addr),
                }]);
            }
        }

        if let Some(records) = self.cache.lookup(name, rtype) {
            return Ok(records);
        }

        // The query is encoded (validating the name) before the nameserver
        // check, so an invalid name is reported even with none configured.
        let mut query_buf = [0u8; MAX_DNS_MSG_SIZE];
        let id = self.next_id;
        self.next_id = self.next_id.wrapping_add(1);
        let query_len = build_query(&mut query_buf, id, name, rtype)?;

        if self.nameservers.is_empty() {
            return Err(DnsError::NoNameservers);
        }

        let mut last_err = DnsError::Timeout;
        for _ns_addr in self.nameservers.iter().copied() {
            // NOTE(review): no transport is wired up here. Presumably this
            // should send `query_buf[..query_len]` to `_ns_addr`, match the
            // reply ID against `id`, then `parse_response` + `cache_records`.
            // Until then every nameserver is treated as timing out.
            let _ = (&query_buf[..query_len], id);
            last_err = DnsError::Timeout;
        }

        Err(last_err)
    }

    /// Inserts each record into the cache under its own name, using its TTL.
    pub fn cache_records(&mut self, records: &[DnsRecord]) {
        for record in records {
            self.cache.insert(&record.name, record.clone(), record.ttl);
        }
    }

    /// Number of entries currently held by the cache.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }

    /// Empties the record cache; hosts and nameservers are untouched.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }
}
1006
1007fn parse_ipv4(s: &str) -> Option<Ipv4Address> {
1013 let mut octets = [0u8; 4];
1014 let mut count = 0;
1015
1016 for part in s.split('.') {
1017 if count >= 4 {
1018 return None;
1019 }
1020 let val: u16 = {
1021 let mut n: u16 = 0;
1022 for &b in part.as_bytes() {
1023 if !b.is_ascii_digit() {
1024 return None;
1025 }
1026 n = n.checked_mul(10)?.checked_add((b - b'0') as u16)?;
1027 }
1028 n
1029 };
1030 if val > 255 {
1031 return None;
1032 }
1033 octets[count] = val as u8;
1034 count += 1;
1035 }
1036
1037 if count != 4 {
1038 return None;
1039 }
1040 Some(Ipv4Address::new(octets[0], octets[1], octets[2], octets[3]))
1041}
1042
#[cfg(feature = "alloc")]
/// Process-wide resolver behind the free functions `resolve`, `add_nameserver`
/// and `add_host`; installed by `init`.
static DNS_RESOLVER: crate::sync::once_lock::GlobalState<Mutex<DnsResolver>> =
    crate::sync::once_lock::GlobalState::new();

#[cfg(feature = "alloc")]
/// Installs a fresh global resolver (no nameservers, `localhost` host entry).
///
/// # Errors
/// `InvalidFormat` if `GlobalState::init` fails. Presumably that means the
/// resolver was already initialised; confirm against `crate::sync::once_lock`.
pub fn init() -> Result<(), DnsError> {
    DNS_RESOLVER
        .init(Mutex::new(DnsResolver::new()))
        .map(|_| ())
        .map_err(|_| DnsError::InvalidFormat)
}
1059}
1060
1061#[cfg(feature = "alloc")]
1063pub fn resolve(name: &str, rtype: DnsRecordType) -> Result<Vec<DnsRecord>, DnsError> {
1064 DNS_RESOLVER
1065 .with(|lock| {
1066 let mut resolver = lock.lock();
1067 resolver.resolve(name, rtype)
1068 })
1069 .unwrap_or(Err(DnsError::NoNameservers))
1070}
1071
#[cfg(feature = "alloc")]
/// Adds a nameserver to the global resolver.
///
/// The result of `DNS_RESOLVER.with` is discarded, so this is presumably a
/// silent no-op when the global resolver is not initialised.
pub fn add_nameserver(addr: Ipv4Address) {
    DNS_RESOLVER.with(|lock| lock.lock().add_nameserver(addr));
}
1080
#[cfg(feature = "alloc")]
/// Adds a static hosts entry to the global resolver.
///
/// The result of `DNS_RESOLVER.with` is discarded, so this is presumably a
/// silent no-op when the global resolver is not initialised.
pub fn add_host(name: &str, addr: Ipv4Address) {
    DNS_RESOLVER.with(|lock| lock.lock().add_host(name, addr));
}
1089
#[cfg(test)]
mod tests {
    // `vec!` is used by tests below; the allow silences an unused-import
    // warning should the name also arrive through the glob import.
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    // Every named record type survives a to_u16 -> from_u16 round trip.
    #[test]
    fn test_dns_record_type_roundtrip() {
        let types = [
            DnsRecordType::A,
            DnsRecordType::AAAA,
            DnsRecordType::CNAME,
            DnsRecordType::MX,
            DnsRecordType::TXT,
            DnsRecordType::PTR,
            DnsRecordType::SRV,
            DnsRecordType::NS,
        ];
        for t in &types {
            assert_eq!(DnsRecordType::from_u16(t.to_u16()), *t);
        }
    }

    // Unlisted TYPE codes fold into `Unknown`.
    #[test]
    fn test_dns_record_type_unknown() {
        assert_eq!(DnsRecordType::from_u16(999), DnsRecordType::Unknown);
    }

    // Three labels: 1+3 + 1+7 + 1+3 + terminator = 17 bytes.
    #[test]
    fn test_encode_name_simple() {
        let mut buf = [0u8; 64];
        let len = encode_name("www.example.com", &mut buf).unwrap();
        assert_eq!(len, 17);
        assert_eq!(buf[0], 3);
        assert_eq!(&buf[1..4], b"www");
        assert_eq!(buf[4], 7);
        assert_eq!(&buf[5..12], b"example");
        assert_eq!(buf[12], 3);
        assert_eq!(&buf[13..16], b"com");
        assert_eq!(buf[16], 0);
    }

    // A single label is length byte + 9 bytes + terminator.
    #[test]
    fn test_encode_name_single_label() {
        let mut buf = [0u8; 64];
        let len = encode_name("localhost", &mut buf).unwrap();
        assert_eq!(len, 11);
        assert_eq!(buf[0], 9);
        assert_eq!(&buf[1..10], b"localhost");
        assert_eq!(buf[10], 0);
    }

    // The empty name encodes as the root: one zero byte.
    #[test]
    fn test_encode_name_empty() {
        let mut buf = [0u8; 64];
        let len = encode_name("", &mut buf).unwrap();
        assert_eq!(len, 1);
        assert_eq!(buf[0], 0);
    }

    #[test]
    fn test_encode_name_buffer_too_small() {
        let mut buf = [0u8; 3];
        let result = encode_name("www.example.com", &mut buf);
        assert_eq!(result, Err(DnsError::BufferTooSmall));
    }

    // 64 bytes is one over the 63-byte label limit.
    #[test]
    fn test_encode_name_label_too_long() {
        let long_label = "a".repeat(64);
        let mut buf = [0u8; 128];
        let result = encode_name(&long_label, &mut buf);
        assert_eq!(result, Err(DnsError::InvalidLabel));
    }

    // Uncompressed name: consumed covers every byte including the terminator.
    #[test]
    fn test_decode_name_simple() {
        let msg = [
            3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm',
            0,
        ];
        let (name, consumed) = decode_name(&msg, 0).unwrap();
        assert_eq!(name, "www.example.com");
        assert_eq!(consumed, 17);
    }

    // "www" at offset 13 followed by a pointer to "example.com" at offset 0.
    // Consumed counts only the in-place bytes: 4 for "www" + 2 for the pointer.
    #[test]
    fn test_decode_name_with_pointer() {
        let mut msg = vec![
            7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0,
        ];
        msg.extend_from_slice(&[3, b'w', b'w', b'w', 0xC0, 0x00]);

        let (name, consumed) = decode_name(&msg, 13).unwrap();
        assert_eq!(name, "www.example.com");
        assert_eq!(consumed, 6);
    }

    // Two pointers referring to each other must hit the jump limit.
    #[test]
    fn test_decode_name_compression_loop() {
        let msg = [0xC0, 0x02, 0xC0, 0x00];
        let result = decode_name(&msg, 0);
        assert_eq!(result, Err(DnsError::CompressionLoop));
    }

    #[test]
    fn test_header_roundtrip() {
        let header = DnsHeader {
            id: 0x1234,
            qr: true,
            opcode: 0,
            aa: true,
            tc: false,
            rd: true,
            ra: true,
            rcode: DnsResponseCode::NoError,
            qdcount: 1,
            ancount: 2,
            nscount: 0,
            arcount: 0,
        };

        let mut buf = [0u8; 12];
        let len = header.to_bytes(&mut buf).unwrap();
        assert_eq!(len, 12);

        let parsed = DnsHeader::from_bytes(&buf).unwrap();
        assert_eq!(parsed.id, 0x1234);
        assert!(parsed.qr);
        assert!(parsed.aa);
        assert!(!parsed.tc);
        assert!(parsed.rd);
        assert!(parsed.ra);
        assert_eq!(parsed.qdcount, 1);
        assert_eq!(parsed.ancount, 2);
    }

    #[test]
    fn test_header_too_short() {
        let buf = [0u8; 6];
        assert_eq!(DnsHeader::from_bytes(&buf), Err(DnsError::MessageTooShort));
    }

    // 12-byte header + 13-byte "example.com" + QTYPE/QCLASS (4) = 29 bytes.
    #[test]
    fn test_build_query() {
        let mut buf = [0u8; 512];
        let len = build_query(&mut buf, 0xABCD, "example.com", DnsRecordType::A).unwrap();

        assert_eq!(len, 29);

        let header = DnsHeader::from_bytes(&buf).unwrap();
        assert_eq!(header.id, 0xABCD);
        assert!(!header.qr);
        assert!(header.rd);
        assert_eq!(header.qdcount, 1);
    }

    // Hand-built response: one question and one A answer whose owner name is
    // a pointer (0xC00C) back to the question name at offset 12.
    #[test]
    fn test_parse_response_a_record() {
        let mut msg = vec![0u8; 512];
        let mut pos = 0;

        // Header: ID 1, flags 0x8180 (QR, RD, RA; RCODE 0), QD=1, AN=1, NS=0, AR=0.
        msg[0..2].copy_from_slice(&1u16.to_be_bytes());
        msg[2..4].copy_from_slice(&0x8180u16.to_be_bytes());
        msg[4..6].copy_from_slice(&1u16.to_be_bytes());
        msg[6..8].copy_from_slice(&1u16.to_be_bytes());
        msg[8..10].copy_from_slice(&0u16.to_be_bytes());
        msg[10..12].copy_from_slice(&0u16.to_be_bytes());
        pos = 12;

        // Question: example.com, QTYPE A, QCLASS IN.
        let name_bytes = [
            7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0,
        ];
        msg[pos..pos + 13].copy_from_slice(&name_bytes);
        pos += 13;
        msg[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes());
        pos += 2;
        msg[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes());
        pos += 2;

        // Answer: pointer to offset 12, TYPE A, CLASS IN, TTL 300, RDLENGTH 4, 93.184.216.34.
        msg[pos] = 0xC0;
        msg[pos + 1] = 0x0C;
        pos += 2;
        msg[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes());
        pos += 2;
        msg[pos..pos + 2].copy_from_slice(&1u16.to_be_bytes());
        pos += 2;
        msg[pos..pos + 4].copy_from_slice(&300u32.to_be_bytes());
        pos += 4;
        msg[pos..pos + 2].copy_from_slice(&4u16.to_be_bytes());
        pos += 2;
        msg[pos..pos + 4].copy_from_slice(&[93, 184, 216, 34]);
        pos += 4;

        let (header, records) = parse_response(&msg[..pos]).unwrap();
        assert_eq!(header.ancount, 1);
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "example.com");
        assert_eq!(records[0].rtype, DnsRecordType::A);
        assert_eq!(records[0].ttl, 300);
        if let DnsRecordData::A(addr) = &records[0].data {
            assert_eq!(addr.0, [93, 184, 216, 34]);
        } else {
            panic!("Expected A record data");
        }
    }

    #[test]
    fn test_cache_insert_and_lookup() {
        let mut cache = DnsCache::new();

        let record = DnsRecord {
            name: String::from("example.com"),
            rtype: DnsRecordType::A,
            rclass: DnsClass::IN,
            ttl: 300,
            data: DnsRecordData::A(Ipv4Address::new(93, 184, 216, 34)),
        };

        cache.insert("example.com", record, 300);
        assert_eq!(cache.len(), 1);

        let result = cache.lookup("example.com", DnsRecordType::A);
        assert!(result.is_some());
        let records = result.unwrap();
        assert_eq!(records.len(), 1);
    }

    // TTLs count logical clock ticks; lookups of other names still advance
    // the clock past the 1-tick TTL.
    #[test]
    fn test_cache_expiry() {
        let mut cache = DnsCache::new();

        let record = DnsRecord {
            name: String::from("expire.test"),
            rtype: DnsRecordType::A,
            rclass: DnsClass::IN,
            ttl: 1,
            data: DnsRecordData::A(Ipv4Address::new(1, 2, 3, 4)),
        };

        cache.insert("expire.test", record, 1);
        assert_eq!(cache.len(), 1);

        for _ in 0..5 {
            let _ = cache.lookup("other.name", DnsRecordType::A);
        }

        let result = cache.lookup("expire.test", DnsRecordType::A);
        assert!(result.is_none());
    }

    // With capacity 4, touching host1/host3 leaves host0 as the LRU entry,
    // so inserting host4 evicts host0.
    #[test]
    fn test_cache_lru_eviction() {
        let mut cache = DnsCache::new();
        cache.max_entries = 4;

        for i in 0..4u8 {
            let name = alloc::format!("host{}.test", i);
            let record = DnsRecord {
                name: name.clone(),
                rtype: DnsRecordType::A,
                rclass: DnsClass::IN,
                ttl: 1000,
                data: DnsRecordData::A(Ipv4Address::new(10, 0, 0, i)),
            };
            cache.insert(&name, record, 1000);
        }
        assert_eq!(cache.len(), 4);

        let _ = cache.lookup("host1.test", DnsRecordType::A);
        let _ = cache.lookup("host3.test", DnsRecordType::A);

        let record = DnsRecord {
            name: String::from("host4.test"),
            rtype: DnsRecordType::A,
            rclass: DnsClass::IN,
            ttl: 1000,
            data: DnsRecordData::A(Ipv4Address::new(10, 0, 0, 4)),
        };
        cache.insert("host4.test", record, 1000);

        assert_eq!(cache.len(), 4);
        assert!(cache.lookup("host0.test", DnsRecordType::A).is_none());
        assert!(cache.lookup("host1.test", DnsRecordType::A).is_some());
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = DnsCache::new();
        let result = cache.lookup("nonexistent.com", DnsRecordType::A);
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_ipv4() {
        assert_eq!(
            parse_ipv4("192.168.1.1"),
            Some(Ipv4Address::new(192, 168, 1, 1))
        );
        assert_eq!(parse_ipv4("0.0.0.0"), Some(Ipv4Address::new(0, 0, 0, 0)));
        assert_eq!(
            parse_ipv4("255.255.255.255"),
            Some(Ipv4Address::new(255, 255, 255, 255))
        );
        assert_eq!(parse_ipv4("256.0.0.1"), None);
        assert_eq!(parse_ipv4("1.2.3"), None);
        assert_eq!(parse_ipv4("abc"), None);
        assert_eq!(parse_ipv4(""), None);
    }

    // The default `localhost` entry plus one added by hand.
    #[test]
    fn test_resolver_hosts_lookup() {
        let mut resolver = DnsResolver::new();
        resolver.add_host("myhost.local", Ipv4Address::new(10, 0, 0, 1));

        assert_eq!(
            resolver.lookup_hosts("localhost"),
            Some(Ipv4Address::LOCALHOST)
        );
        assert_eq!(
            resolver.lookup_hosts("myhost.local"),
            Some(Ipv4Address::new(10, 0, 0, 1))
        );
        assert_eq!(resolver.lookup_hosts("unknown.host"), None);
    }

    // Commented-out nameserver lines must be ignored.
    #[test]
    fn test_resolver_parse_resolv_conf() {
        let mut resolver = DnsResolver::new();
        let content = "\
# DNS config
nameserver 8.8.8.8
nameserver 8.8.4.4
# nameserver 1.1.1.1
nameserver 9.9.9.9
";
        resolver.parse_resolv_conf(content);
        assert_eq!(resolver.nameservers.len(), 3);
        assert_eq!(resolver.nameservers[0], Ipv4Address::new(8, 8, 8, 8));
        assert_eq!(resolver.nameservers[1], Ipv4Address::new(8, 8, 4, 4));
        assert_eq!(resolver.nameservers[2], Ipv4Address::new(9, 9, 9, 9));
    }

    // Every name on a line (canonical name and aliases) maps to its address.
    #[test]
    fn test_resolver_parse_hosts_file() {
        let mut resolver = DnsResolver::new();
        let content = "\
127.0.0.1 localhost loopback
192.168.1.100 myserver.local myserver
# comment line
10.0.0.1 gateway.local
";
        resolver.parse_hosts(content);
        assert_eq!(
            resolver.lookup_hosts("myserver.local"),
            Some(Ipv4Address::new(192, 168, 1, 100))
        );
        assert_eq!(
            resolver.lookup_hosts("myserver"),
            Some(Ipv4Address::new(192, 168, 1, 100))
        );
        assert_eq!(
            resolver.lookup_hosts("gateway.local"),
            Some(Ipv4Address::new(10, 0, 0, 1))
        );
    }

    // Hosts entries answer A queries even with no nameserver configured.
    #[test]
    fn test_resolver_hosts_before_network() {
        let mut resolver = DnsResolver::new();

        let result = resolver.resolve("localhost", DnsRecordType::A);
        assert!(result.is_ok());
        let records = result.unwrap();
        assert_eq!(records.len(), 1);
        if let DnsRecordData::A(addr) = &records[0].data {
            assert_eq!(*addr, Ipv4Address::LOCALHOST);
        } else {
            panic!("Expected A record");
        }
    }

    // Only the low nibble counts; unassigned codes map to ServerFailure.
    #[test]
    fn test_response_code_parse() {
        assert_eq!(DnsResponseCode::from_u8(0), DnsResponseCode::NoError);
        assert_eq!(DnsResponseCode::from_u8(3), DnsResponseCode::NameError);
        assert_eq!(DnsResponseCode::from_u8(5), DnsResponseCode::Refused);
        assert_eq!(
            DnsResponseCode::from_u8(0xFF),
            DnsResponseCode::ServerFailure
        );
    }

    #[test]
    fn test_dns_class_roundtrip() {
        assert_eq!(DnsClass::from_u16(1), DnsClass::IN);
        assert_eq!(DnsClass::from_u16(255), DnsClass::ANY);
        assert_eq!(DnsClass::IN.to_u16(), 1);
    }
}