1#![allow(dead_code)]
28
29use alloc::{string::String, vec, vec::Vec};
30
31use spin::Mutex;
32
33use super::{Ipv4Address, MacAddress};
34use crate::error::KernelError;
35
/// Largest payload, in bytes, a single message may carry; enforced by
/// `NetlinkMessage::set_payload` and `NetlinkMessage::deserialize`.
const MAX_PAYLOAD_SIZE: usize = 4 * 1024;

/// Maximum number of unread replies buffered on one socket's receive queue.
const MAX_QUEUE_DEPTH: usize = 64;

/// Maximum number of netlink sockets that may be open at the same time.
const MAX_SOCKETS: usize = 32;

/// Width in bytes of the interface-name field in the `LinkInfo` wire format;
/// longer names are truncated when encoded.
const MAX_IFNAME_LEN: usize = 16;
51
/// Message types understood by the netlink handlers in this module.
///
/// Discriminants are the raw `u16` values carried in `NetlinkHeader::msg_type`.
/// They are grouped by subsystem in blocks of 16 (`0x0_`, `0x1_`, `0x2_`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum NetlinkMessageType {
    // Control messages.
    Noop = 0x00,
    /// Error reply; `NetlinkMessage::error` stores a little-endian `i32` errno as payload.
    Error = 0x01,
    /// End-of-reply marker; `NetlinkMessage::done` builds it with an empty payload.
    Done = 0x02,

    // Link (interface) messages.
    LinkUp = 0x10,
    LinkDown = 0x11,
    GetLink = 0x12,
    NewLink = 0x13,
    DelLink = 0x14,

    // Address messages.
    AddrAdd = 0x20,
    AddrDel = 0x21,
    GetAddr = 0x22,
    NewAddr = 0x23,

    // Route messages.
    RouteAdd = 0x30,
    RouteDel = 0x31,
    GetRoute = 0x32,
    NewRoute = 0x33,

    // "Get all" requests for links, addresses and routes.
    GetLinks = 0x40,
    GetAddrs = 0x41,
    GetRoutes = 0x42,
}

impl NetlinkMessageType {
    /// Decodes a raw wire value.
    ///
    /// Returns `None` if `val` is not the discriminant of any variant.
    pub fn from_u16(val: u16) -> Option<Self> {
        // Every variant exactly once; decoding is a lookup by discriminant.
        const ALL: [NetlinkMessageType; 19] = [
            NetlinkMessageType::Noop,
            NetlinkMessageType::Error,
            NetlinkMessageType::Done,
            NetlinkMessageType::LinkUp,
            NetlinkMessageType::LinkDown,
            NetlinkMessageType::GetLink,
            NetlinkMessageType::NewLink,
            NetlinkMessageType::DelLink,
            NetlinkMessageType::AddrAdd,
            NetlinkMessageType::AddrDel,
            NetlinkMessageType::GetAddr,
            NetlinkMessageType::NewAddr,
            NetlinkMessageType::RouteAdd,
            NetlinkMessageType::RouteDel,
            NetlinkMessageType::GetRoute,
            NetlinkMessageType::NewRoute,
            NetlinkMessageType::GetLinks,
            NetlinkMessageType::GetAddrs,
            NetlinkMessageType::GetRoutes,
        ];

        ALL.iter().copied().find(|&kind| kind as u16 == val)
    }
}
145
/// Bit flags carried in `NetlinkHeader::flags`; combine them with `|`.
///
/// None of these flags are interpreted by the handlers in this module. The names
/// follow Linux's `NLM_F_*`, but `NLM_F_DUMP`, `NLM_F_CREATE` and `NLM_F_REPLACE`
/// do NOT have the values used in `<linux/netlink.h>`. Do not assume wire
/// compatibility with Linux netlink.
pub mod flags {
    /// Message is a request.
    pub const NLM_F_REQUEST: u16 = 1 << 0;
    /// Message is part of a multipart reply.
    pub const NLM_F_MULTI: u16 = 1 << 1;
    /// Sender asks for an acknowledgement.
    pub const NLM_F_ACK: u16 = 1 << 2;
    /// Request a dump of all matching objects.
    pub const NLM_F_DUMP: u16 = 1 << 8;
    /// Create the object if it does not exist.
    pub const NLM_F_CREATE: u16 = 1 << 9;
    /// Replace an existing object.
    pub const NLM_F_REPLACE: u16 = 1 << 10;
}
165
166#[derive(Debug, Clone, Copy)]
172#[repr(C)]
173pub struct NetlinkHeader {
174 pub msg_type: u16,
176 pub flags: u16,
178 pub seq: u32,
180 pub pid: u32,
182 pub payload_len: u32,
184}
185
186impl NetlinkHeader {
187 pub const SIZE: usize = 16;
189
190 pub const fn new(msg_type: u16, flags: u16, seq: u32, pid: u32) -> Self {
192 Self {
193 msg_type,
194 flags,
195 seq,
196 pid,
197 payload_len: 0,
198 }
199 }
200
201 pub fn serialize(&self, buf: &mut [u8]) -> Result<usize, KernelError> {
203 if buf.len() < Self::SIZE {
204 return Err(KernelError::InvalidArgument {
205 name: "netlink",
206 value: "invalid",
207 });
208 }
209
210 buf[0..2].copy_from_slice(&self.msg_type.to_le_bytes());
211 buf[2..4].copy_from_slice(&self.flags.to_le_bytes());
212 buf[4..8].copy_from_slice(&self.seq.to_le_bytes());
213 buf[8..12].copy_from_slice(&self.pid.to_le_bytes());
214 buf[12..16].copy_from_slice(&self.payload_len.to_le_bytes());
215
216 Ok(Self::SIZE)
217 }
218
219 pub fn deserialize(buf: &[u8]) -> Result<Self, KernelError> {
221 if buf.len() < Self::SIZE {
222 return Err(KernelError::InvalidArgument {
223 name: "netlink",
224 value: "invalid",
225 });
226 }
227
228 Ok(Self {
229 msg_type: u16::from_le_bytes([buf[0], buf[1]]),
230 flags: u16::from_le_bytes([buf[2], buf[3]]),
231 seq: u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]),
232 pid: u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]),
233 payload_len: u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]),
234 })
235 }
236}
237
238#[derive(Debug, Clone)]
244pub struct NetlinkMessage {
245 pub header: NetlinkHeader,
247 pub payload: Vec<u8>,
249}
250
251impl NetlinkMessage {
252 pub fn new(msg_type: NetlinkMessageType, flags: u16, seq: u32, pid: u32) -> Self {
254 Self {
255 header: NetlinkHeader::new(msg_type as u16, flags, seq, pid),
256 payload: Vec::new(),
257 }
258 }
259
260 pub fn error(seq: u32, pid: u32, errno: i32) -> Self {
262 let mut msg = Self::new(NetlinkMessageType::Error, 0, seq, pid);
263 msg.payload = errno.to_le_bytes().to_vec();
264 msg.header.payload_len = 4;
265 msg
266 }
267
268 pub fn done(seq: u32, pid: u32) -> Self {
270 Self::new(NetlinkMessageType::Done, 0, seq, pid)
271 }
272
273 pub fn set_payload(&mut self, data: &[u8]) -> Result<(), KernelError> {
275 if data.len() > MAX_PAYLOAD_SIZE {
276 return Err(KernelError::InvalidArgument {
277 name: "netlink",
278 value: "invalid",
279 });
280 }
281 self.payload = data.to_vec();
282 self.header.payload_len = data.len() as u32;
283 Ok(())
284 }
285
286 pub fn total_size(&self) -> usize {
288 NetlinkHeader::SIZE
289 .checked_add(self.payload.len())
290 .unwrap_or(NetlinkHeader::SIZE)
291 }
292
293 pub fn serialize(&self) -> Result<Vec<u8>, KernelError> {
295 let total = self.total_size();
296 let mut buf = vec![0u8; total];
297
298 self.header.serialize(&mut buf[..NetlinkHeader::SIZE])?;
299
300 if !self.payload.is_empty() {
301 buf[NetlinkHeader::SIZE..].copy_from_slice(&self.payload);
302 }
303
304 Ok(buf)
305 }
306
307 pub fn deserialize(buf: &[u8]) -> Result<Self, KernelError> {
309 let header = NetlinkHeader::deserialize(buf)?;
310 let payload_len = header.payload_len as usize;
311
312 if payload_len > MAX_PAYLOAD_SIZE {
313 return Err(KernelError::InvalidArgument {
314 name: "netlink",
315 value: "invalid",
316 });
317 }
318
319 let total =
320 NetlinkHeader::SIZE
321 .checked_add(payload_len)
322 .ok_or(KernelError::InvalidArgument {
323 name: "netlink",
324 value: "invalid",
325 })?;
326
327 if buf.len() < total {
328 return Err(KernelError::InvalidArgument {
329 name: "netlink",
330 value: "invalid",
331 });
332 }
333
334 let payload = if payload_len > 0 {
335 buf[NetlinkHeader::SIZE..total].to_vec()
336 } else {
337 Vec::new()
338 };
339
340 Ok(Self { header, payload })
341 }
342}
343
344#[derive(Debug, Clone)]
350pub struct LinkInfo {
351 pub index: u32,
353 pub name: String,
355 pub mac: MacAddress,
357 pub mtu: u32,
359 pub flags: u32,
361 pub if_type: u16,
363 pub speed: u32,
365}
366
367impl LinkInfo {
368 pub fn serialize(&self) -> Vec<u8> {
370 let name_bytes = self.name.as_bytes();
371 let name_len = name_bytes.len().min(MAX_IFNAME_LEN);
372
373 let mut buf = vec![0u8; 42];
376
377 buf[0..4].copy_from_slice(&self.index.to_le_bytes());
378 buf[4..6].copy_from_slice(&(name_len as u16).to_le_bytes());
379 buf[6..6 + name_len].copy_from_slice(&name_bytes[..name_len]);
380 buf[22..28].copy_from_slice(&self.mac.0);
382 buf[28..32].copy_from_slice(&self.mtu.to_le_bytes());
383 buf[32..36].copy_from_slice(&self.flags.to_le_bytes());
384 buf[36..38].copy_from_slice(&self.if_type.to_le_bytes());
385 buf[38..42].copy_from_slice(&self.speed.to_le_bytes());
386
387 buf
388 }
389
390 pub fn deserialize(buf: &[u8]) -> Result<Self, KernelError> {
392 if buf.len() < 42 {
393 return Err(KernelError::InvalidArgument {
394 name: "netlink",
395 value: "invalid",
396 });
397 }
398
399 let index = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
400 let name_len = u16::from_le_bytes([buf[4], buf[5]]) as usize;
401 let name_len = name_len.min(MAX_IFNAME_LEN);
402
403 let name = core::str::from_utf8(&buf[6..6 + name_len])
404 .map_err(|_| KernelError::InvalidArgument {
405 name: "netlink",
406 value: "invalid",
407 })?
408 .into();
409
410 let mut mac_bytes = [0u8; 6];
411 mac_bytes.copy_from_slice(&buf[22..28]);
412
413 let mtu = u32::from_le_bytes([buf[28], buf[29], buf[30], buf[31]]);
414 let flags = u32::from_le_bytes([buf[32], buf[33], buf[34], buf[35]]);
415 let if_type = u16::from_le_bytes([buf[36], buf[37]]);
416 let speed = u32::from_le_bytes([buf[38], buf[39], buf[40], buf[41]]);
417
418 Ok(Self {
419 index,
420 name,
421 mac: MacAddress(mac_bytes),
422 mtu,
423 flags,
424 if_type,
425 speed,
426 })
427 }
428}
429
430#[derive(Debug, Clone)]
436pub struct AddrInfo {
437 pub index: u32,
439 pub family: u8,
441 pub prefix_len: u8,
443 pub addr_v4: Ipv4Address,
445}
446
447impl AddrInfo {
448 pub fn serialize(&self) -> Vec<u8> {
450 let mut buf = vec![0u8; 10];
452 buf[0..4].copy_from_slice(&self.index.to_le_bytes());
453 buf[4] = self.family;
454 buf[5] = self.prefix_len;
455 buf[6..10].copy_from_slice(&self.addr_v4.0);
456 buf
457 }
458
459 pub fn deserialize(buf: &[u8]) -> Result<Self, KernelError> {
461 if buf.len() < 10 {
462 return Err(KernelError::InvalidArgument {
463 name: "netlink",
464 value: "invalid",
465 });
466 }
467
468 let index = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
469 let family = buf[4];
470 let prefix_len = buf[5];
471 let addr_v4 = Ipv4Address([buf[6], buf[7], buf[8], buf[9]]);
472
473 Ok(Self {
474 index,
475 family,
476 prefix_len,
477 addr_v4,
478 })
479 }
480}
481
482#[derive(Debug, Clone)]
488pub struct RouteInfo {
489 pub dest: Ipv4Address,
491 pub dest_prefix: u8,
493 pub gateway: Ipv4Address,
495 pub oif_index: u32,
497 pub metric: u32,
499}
500
501impl RouteInfo {
502 pub fn serialize(&self) -> Vec<u8> {
504 let mut buf = vec![0u8; 17];
506 buf[0..4].copy_from_slice(&self.dest.0);
507 buf[4] = self.dest_prefix;
508 buf[5..9].copy_from_slice(&self.gateway.0);
509 buf[9..13].copy_from_slice(&self.oif_index.to_le_bytes());
510 buf[13..17].copy_from_slice(&self.metric.to_le_bytes());
511 buf
512 }
513
514 pub fn deserialize(buf: &[u8]) -> Result<Self, KernelError> {
516 if buf.len() < 17 {
517 return Err(KernelError::InvalidArgument {
518 name: "netlink",
519 value: "invalid",
520 });
521 }
522
523 Ok(Self {
524 dest: Ipv4Address([buf[0], buf[1], buf[2], buf[3]]),
525 dest_prefix: buf[4],
526 gateway: Ipv4Address([buf[5], buf[6], buf[7], buf[8]]),
527 oif_index: u32::from_le_bytes([buf[9], buf[10], buf[11], buf[12]]),
528 metric: u32::from_le_bytes([buf[13], buf[14], buf[15], buf[16]]),
529 })
530 }
531}
532
533pub struct NetlinkSocket {
539 id: u32,
541 pid: u32,
543 rx_queue: Vec<NetlinkMessage>,
545 next_seq: u32,
547}
548
549impl NetlinkSocket {
550 pub fn new(id: u32, pid: u32) -> Self {
552 Self {
553 id,
554 pid,
555 rx_queue: Vec::new(),
556 next_seq: 1,
557 }
558 }
559
560 pub fn id(&self) -> u32 {
562 self.id
563 }
564
565 pub fn pid(&self) -> u32 {
567 self.pid
568 }
569
570 pub fn next_seq(&mut self) -> u32 {
572 let seq = self.next_seq;
573 self.next_seq = self.next_seq.wrapping_add(1);
574 seq
575 }
576
577 pub fn has_pending(&self) -> bool {
579 !self.rx_queue.is_empty()
580 }
581
582 pub fn pending_count(&self) -> usize {
584 self.rx_queue.len()
585 }
586}
587
588struct NetlinkRegistry {
593 sockets: Vec<NetlinkSocket>,
594 next_id: u32,
595}
596
597impl NetlinkRegistry {
598 const fn new() -> Self {
599 Self {
600 sockets: Vec::new(),
601 next_id: 1,
602 }
603 }
604}
605
606static REGISTRY: Mutex<NetlinkRegistry> = Mutex::new(NetlinkRegistry::new());
607
608pub fn netlink_send(socket_id: u32, msg: &NetlinkMessage) -> Result<(), KernelError> {
617 let mut registry = REGISTRY.lock();
618
619 let socket = registry
620 .sockets
621 .iter_mut()
622 .find(|s| s.id == socket_id)
623 .ok_or(KernelError::InvalidArgument {
624 name: "netlink",
625 value: "invalid",
626 })?;
627
628 let msg_type =
629 NetlinkMessageType::from_u16(msg.header.msg_type).ok_or(KernelError::InvalidArgument {
630 name: "netlink",
631 value: "invalid",
632 })?;
633
634 match msg_type {
635 NetlinkMessageType::LinkUp | NetlinkMessageType::LinkDown => {
636 if msg.payload.len() < MAX_IFNAME_LEN {
638 let ack = NetlinkMessage::done(msg.header.seq, 0);
640 if socket.rx_queue.len() < MAX_QUEUE_DEPTH {
641 socket.rx_queue.push(ack);
642 }
643 } else {
644 let err = NetlinkMessage::error(msg.header.seq, 0, -22); if socket.rx_queue.len() < MAX_QUEUE_DEPTH {
646 socket.rx_queue.push(err);
647 }
648 }
649 }
650 NetlinkMessageType::GetLinks => {
651 let done = NetlinkMessage::done(msg.header.seq, 0);
654 if socket.rx_queue.len() < MAX_QUEUE_DEPTH {
655 socket.rx_queue.push(done);
656 }
657 }
658 NetlinkMessageType::AddrAdd | NetlinkMessageType::AddrDel => {
659 let ack = NetlinkMessage::done(msg.header.seq, 0);
661 if socket.rx_queue.len() < MAX_QUEUE_DEPTH {
662 socket.rx_queue.push(ack);
663 }
664 }
665 NetlinkMessageType::GetAddrs => {
666 let done = NetlinkMessage::done(msg.header.seq, 0);
668 if socket.rx_queue.len() < MAX_QUEUE_DEPTH {
669 socket.rx_queue.push(done);
670 }
671 }
672 NetlinkMessageType::RouteAdd | NetlinkMessageType::RouteDel => {
673 let ack = NetlinkMessage::done(msg.header.seq, 0);
675 if socket.rx_queue.len() < MAX_QUEUE_DEPTH {
676 socket.rx_queue.push(ack);
677 }
678 }
679 NetlinkMessageType::GetRoutes => {
680 let done = NetlinkMessage::done(msg.header.seq, 0);
682 if socket.rx_queue.len() < MAX_QUEUE_DEPTH {
683 socket.rx_queue.push(done);
684 }
685 }
686 _ => {
687 let err = NetlinkMessage::error(msg.header.seq, 0, -95); if socket.rx_queue.len() < MAX_QUEUE_DEPTH {
689 socket.rx_queue.push(err);
690 }
691 }
692 }
693
694 Ok(())
695}
696
697pub fn netlink_recv(socket_id: u32) -> Result<Option<NetlinkMessage>, KernelError> {
702 let mut registry = REGISTRY.lock();
703
704 let socket = registry
705 .sockets
706 .iter_mut()
707 .find(|s| s.id == socket_id)
708 .ok_or(KernelError::InvalidArgument {
709 name: "netlink",
710 value: "invalid",
711 })?;
712
713 if socket.rx_queue.is_empty() {
714 return Ok(None);
715 }
716
717 Ok(Some(socket.rx_queue.remove(0)))
718}
719
720pub fn netlink_open(pid: u32) -> Result<u32, KernelError> {
724 let mut registry = REGISTRY.lock();
725
726 if registry.sockets.len() >= MAX_SOCKETS {
727 return Err(KernelError::ResourceExhausted {
728 resource: "netlink_sockets",
729 });
730 }
731
732 let id = registry.next_id;
733 registry.next_id = registry.next_id.wrapping_add(1);
734
735 let socket = NetlinkSocket::new(id, pid);
736 registry.sockets.push(socket);
737
738 Ok(id)
739}
740
741pub fn netlink_close(socket_id: u32) -> Result<(), KernelError> {
743 let mut registry = REGISTRY.lock();
744
745 let pos = registry
746 .sockets
747 .iter()
748 .position(|s| s.id == socket_id)
749 .ok_or(KernelError::InvalidArgument {
750 name: "netlink",
751 value: "invalid",
752 })?;
753
754 registry.sockets.remove(pos);
755 Ok(())
756}
757
758pub fn init() -> Result<(), KernelError> {
760 Ok(())
762}
763
#[cfg(test)]
mod tests {
    //! Unit tests for the netlink wire formats and the socket API.
    //!
    //! The socket tests share the global `REGISTRY` and may run in parallel, so
    //! each test closes every socket it opens.

    use alloc::vec;

    use super::*;

    /// Decodes the little-endian `i32` errno carried in an `Error` reply's payload.
    fn errno_of(msg: &NetlinkMessage) -> i32 {
        i32::from_le_bytes([
            msg.payload[0],
            msg.payload[1],
            msg.payload[2],
            msg.payload[3],
        ])
    }

    #[test]
    fn test_header_serialize_deserialize() {
        let header = NetlinkHeader::new(
            NetlinkMessageType::GetLinks as u16,
            flags::NLM_F_REQUEST | flags::NLM_F_DUMP,
            42,
            1000,
        );

        let mut buf = [0u8; NetlinkHeader::SIZE];
        header.serialize(&mut buf).unwrap();

        let decoded = NetlinkHeader::deserialize(&buf).unwrap();
        assert_eq!(decoded.msg_type, NetlinkMessageType::GetLinks as u16);
        assert_eq!(decoded.flags, flags::NLM_F_REQUEST | flags::NLM_F_DUMP);
        assert_eq!(decoded.seq, 42);
        assert_eq!(decoded.pid, 1000);
    }

    #[test]
    fn test_header_rejects_short_buffers() {
        let header = NetlinkHeader::new(NetlinkMessageType::Noop as u16, 0, 1, 1);
        let mut short = [0u8; NetlinkHeader::SIZE - 1];
        assert!(header.serialize(&mut short).is_err());
        assert!(NetlinkHeader::deserialize(&short).is_err());
    }

    #[test]
    fn test_message_serialize_deserialize() {
        let mut msg =
            NetlinkMessage::new(NetlinkMessageType::AddrAdd, flags::NLM_F_REQUEST, 1, 100);

        let addr = AddrInfo {
            index: 2,
            family: 2,
            prefix_len: 24,
            addr_v4: Ipv4Address::new(192, 168, 1, 100),
        };
        msg.set_payload(&addr.serialize()).unwrap();

        let bytes = msg.serialize().unwrap();
        let decoded = NetlinkMessage::deserialize(&bytes).unwrap();

        assert_eq!(decoded.header.msg_type, NetlinkMessageType::AddrAdd as u16);
        assert_eq!(decoded.payload.len(), 10);

        let decoded_addr = AddrInfo::deserialize(&decoded.payload).unwrap();
        assert_eq!(decoded_addr.index, 2);
        assert_eq!(decoded_addr.prefix_len, 24);
        assert_eq!(decoded_addr.addr_v4, Ipv4Address::new(192, 168, 1, 100));
    }

    #[test]
    fn test_message_rejects_truncated_payload() {
        // Header claims 8 payload bytes but only 4 follow.
        let mut header = NetlinkHeader::new(NetlinkMessageType::Noop as u16, 0, 1, 1);
        header.payload_len = 8;
        let mut buf = vec![0u8; NetlinkHeader::SIZE + 4];
        header.serialize(&mut buf).unwrap();
        assert!(NetlinkMessage::deserialize(&buf).is_err());
    }

    #[test]
    fn test_message_rejects_oversized_payload_len() {
        let mut header = NetlinkHeader::new(NetlinkMessageType::Noop as u16, 0, 1, 1);
        header.payload_len = (MAX_PAYLOAD_SIZE + 1) as u32;
        let mut buf = vec![0u8; NetlinkHeader::SIZE + MAX_PAYLOAD_SIZE + 1];
        header.serialize(&mut buf).unwrap();
        assert!(NetlinkMessage::deserialize(&buf).is_err());
    }

    #[test]
    fn test_set_payload_size_limit() {
        let mut msg = NetlinkMessage::new(NetlinkMessageType::Noop, 0, 1, 1);

        assert!(msg.set_payload(&vec![0u8; MAX_PAYLOAD_SIZE + 1]).is_err());
        assert!(msg.payload.is_empty());
        assert_eq!(msg.header.payload_len, 0);

        msg.set_payload(&vec![0u8; MAX_PAYLOAD_SIZE]).unwrap();
        assert_eq!(msg.header.payload_len as usize, MAX_PAYLOAD_SIZE);
    }

    #[test]
    fn test_link_info_serialize_deserialize() {
        let link = LinkInfo {
            index: 1,
            name: String::from("eth0"),
            mac: MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55]),
            mtu: 1500,
            flags: 0x1043,
            if_type: 1,
            speed: 1000,
        };

        let bytes = link.serialize();
        let decoded = LinkInfo::deserialize(&bytes).unwrap();

        assert_eq!(decoded.index, 1);
        assert_eq!(decoded.name, "eth0");
        assert_eq!(
            decoded.mac,
            MacAddress::new([0x00, 0x11, 0x22, 0x33, 0x44, 0x55])
        );
        assert_eq!(decoded.mtu, 1500);
        assert_eq!(decoded.flags, 0x1043);
        assert_eq!(decoded.if_type, 1);
        assert_eq!(decoded.speed, 1000);
    }

    #[test]
    fn test_route_info_serialize_deserialize() {
        let route = RouteInfo {
            dest: Ipv4Address::new(0, 0, 0, 0),
            dest_prefix: 0,
            gateway: Ipv4Address::new(192, 168, 1, 1),
            oif_index: 2,
            metric: 100,
        };

        let bytes = route.serialize();
        let decoded = RouteInfo::deserialize(&bytes).unwrap();

        assert_eq!(decoded.dest, Ipv4Address::new(0, 0, 0, 0));
        assert_eq!(decoded.dest_prefix, 0);
        assert_eq!(decoded.gateway, Ipv4Address::new(192, 168, 1, 1));
        assert_eq!(decoded.oif_index, 2);
        assert_eq!(decoded.metric, 100);
    }

    #[test]
    fn test_info_deserialize_rejects_short_buffers() {
        assert!(LinkInfo::deserialize(&[0u8; 41]).is_err());
        assert!(AddrInfo::deserialize(&[0u8; 9]).is_err());
        assert!(RouteInfo::deserialize(&[0u8; 16]).is_err());
    }

    #[test]
    fn test_netlink_socket_open_close() {
        let id = netlink_open(1234).unwrap();
        assert!(id > 0);
        netlink_close(id).unwrap();
    }

    #[test]
    fn test_netlink_closed_socket_is_invalid() {
        let id = netlink_open(4325).unwrap();
        netlink_close(id).unwrap();

        assert!(netlink_close(id).is_err());
        assert!(netlink_recv(id).is_err());
        let msg = NetlinkMessage::new(NetlinkMessageType::GetLinks, flags::NLM_F_REQUEST, 1, 4325);
        assert!(netlink_send(id, &msg).is_err());
    }

    #[test]
    fn test_netlink_send_recv() {
        let id = netlink_open(5678).unwrap();

        let msg = NetlinkMessage::new(
            NetlinkMessageType::GetLinks,
            flags::NLM_F_REQUEST | flags::NLM_F_DUMP,
            1,
            5678,
        );

        netlink_send(id, &msg).unwrap();

        let response = netlink_recv(id).unwrap();
        assert!(response.is_some());

        let resp = response.unwrap();
        assert_eq!(resp.header.msg_type, NetlinkMessageType::Done as u16);

        netlink_close(id).unwrap();
    }

    #[test]
    fn test_netlink_recv_is_fifo() {
        let id = netlink_open(4324).unwrap();

        let first = NetlinkMessage::new(NetlinkMessageType::GetLinks, flags::NLM_F_REQUEST, 1, 4324);
        let second = NetlinkMessage::new(NetlinkMessageType::GetAddrs, flags::NLM_F_REQUEST, 2, 4324);
        netlink_send(id, &first).unwrap();
        netlink_send(id, &second).unwrap();

        assert_eq!(netlink_recv(id).unwrap().unwrap().header.seq, 1);
        assert_eq!(netlink_recv(id).unwrap().unwrap().header.seq, 2);
        assert!(netlink_recv(id).unwrap().is_none());

        netlink_close(id).unwrap();
    }

    #[test]
    fn test_netlink_send_unsupported_type() {
        let id = netlink_open(4321).unwrap();

        let msg = NetlinkMessage::new(NetlinkMessageType::Noop, flags::NLM_F_REQUEST, 7, 4321);
        netlink_send(id, &msg).unwrap();

        let reply = netlink_recv(id).unwrap().unwrap();
        assert_eq!(reply.header.msg_type, NetlinkMessageType::Error as u16);
        assert_eq!(reply.header.seq, 7);
        assert_eq!(errno_of(&reply), -95);

        netlink_close(id).unwrap();
    }

    #[test]
    fn test_netlink_send_unknown_type_is_rejected() {
        let id = netlink_open(4322).unwrap();

        let mut msg = NetlinkMessage::new(NetlinkMessageType::Noop, 0, 1, 4322);
        msg.header.msg_type = 999;
        assert!(netlink_send(id, &msg).is_err());
        // Nothing is queued for a rejected request.
        assert!(netlink_recv(id).unwrap().is_none());

        netlink_close(id).unwrap();
    }

    #[test]
    fn test_netlink_link_up_payload_validation() {
        let id = netlink_open(4323).unwrap();

        let mut short = NetlinkMessage::new(NetlinkMessageType::LinkUp, flags::NLM_F_REQUEST, 1, 4323);
        short.set_payload(b"eth0").unwrap();
        netlink_send(id, &short).unwrap();
        let reply = netlink_recv(id).unwrap().unwrap();
        assert_eq!(reply.header.msg_type, NetlinkMessageType::Done as u16);
        assert_eq!(reply.header.seq, 1);

        let mut long = NetlinkMessage::new(NetlinkMessageType::LinkDown, flags::NLM_F_REQUEST, 2, 4323);
        long.set_payload(&[b'x'; MAX_IFNAME_LEN]).unwrap();
        netlink_send(id, &long).unwrap();
        let reply = netlink_recv(id).unwrap().unwrap();
        assert_eq!(reply.header.msg_type, NetlinkMessageType::Error as u16);
        assert_eq!(reply.header.seq, 2);
        assert_eq!(errno_of(&reply), -22);

        netlink_close(id).unwrap();
    }

    #[test]
    fn test_netlink_recv_empty() {
        let id = netlink_open(9999).unwrap();
        let response = netlink_recv(id).unwrap();
        assert!(response.is_none());
        netlink_close(id).unwrap();
    }

    #[test]
    fn test_message_type_from_u16() {
        assert_eq!(
            NetlinkMessageType::from_u16(16),
            Some(NetlinkMessageType::LinkUp)
        );
        assert_eq!(
            NetlinkMessageType::from_u16(48),
            Some(NetlinkMessageType::RouteAdd)
        );
        assert_eq!(NetlinkMessageType::from_u16(999), None);
    }

    #[test]
    fn test_error_message() {
        let err = NetlinkMessage::error(42, 100, -22);
        assert_eq!(err.header.msg_type, NetlinkMessageType::Error as u16);
        assert_eq!(err.payload.len(), 4);
        assert_eq!(errno_of(&err), -22);
    }
}