1use alloc::{collections::BTreeMap, vec::Vec};
4use core::sync::atomic::{AtomicUsize, Ordering};
5
6use spin::Mutex;
7
8use super::{IpAddress, SocketAddr};
9use crate::error::KernelError;
10
/// Source of socket identifiers handed out by `UdpSocket::new` (first id is 1).
/// The ids are used as keys into `UDP_SOCKETS`.
static NEXT_UDP_ID: AtomicUsize = AtomicUsize::new(1);
13
14#[derive(Debug, Clone)]
16pub struct UdpHeader {
17 pub source_port: u16,
18 pub dest_port: u16,
19 pub length: u16,
20 pub checksum: u16,
21}
22
23impl UdpHeader {
24 pub const SIZE: usize = 8;
25
26 pub fn new(src_port: u16, dst_port: u16, data_len: usize) -> Self {
27 Self {
28 source_port: src_port,
29 dest_port: dst_port,
30 length: (Self::SIZE + data_len) as u16,
31 checksum: 0,
32 }
33 }
34
35 pub fn to_bytes(&self) -> [u8; 8] {
36 let mut bytes = [0u8; 8];
37 bytes[0..2].copy_from_slice(&self.source_port.to_be_bytes());
38 bytes[2..4].copy_from_slice(&self.dest_port.to_be_bytes());
39 bytes[4..6].copy_from_slice(&self.length.to_be_bytes());
40 bytes[6..8].copy_from_slice(&self.checksum.to_be_bytes());
41 bytes
42 }
43
44 pub fn from_bytes(bytes: &[u8]) -> Result<Self, KernelError> {
45 if bytes.len() < Self::SIZE {
46 return Err(KernelError::InvalidArgument {
47 name: "udp_header",
48 value: "too_short",
49 });
50 }
51
52 Ok(Self {
53 source_port: u16::from_be_bytes([bytes[0], bytes[1]]),
54 dest_port: u16::from_be_bytes([bytes[2], bytes[3]]),
55 length: u16::from_be_bytes([bytes[4], bytes[5]]),
56 checksum: u16::from_be_bytes([bytes[6], bytes[7]]),
57 })
58 }
59
60 pub fn calculate_checksum(&mut self, src: IpAddress, dst: IpAddress, data: &[u8]) {
62 self.checksum = 0;
63
64 let mut sum: u32 = 0;
66
67 if let (IpAddress::V4(src_v4), IpAddress::V4(dst_v4)) = (src, dst) {
69 sum += u16::from_be_bytes([src_v4.0[0], src_v4.0[1]]) as u32;
70 sum += u16::from_be_bytes([src_v4.0[2], src_v4.0[3]]) as u32;
71 sum += u16::from_be_bytes([dst_v4.0[0], dst_v4.0[1]]) as u32;
72 sum += u16::from_be_bytes([dst_v4.0[2], dst_v4.0[3]]) as u32;
73 sum += 17u32; sum += self.length as u32;
75 }
76
77 let header_bytes = self.to_bytes();
79 for i in 0..4 {
80 sum += u16::from_be_bytes([header_bytes[i * 2], header_bytes[i * 2 + 1]]) as u32;
81 }
82
83 for chunk in data.chunks(2) {
85 if chunk.len() == 2 {
86 sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
87 } else {
88 sum += (chunk[0] as u32) << 8;
89 }
90 }
91
92 while sum >> 16 != 0 {
94 sum = (sum & 0xFFFF) + (sum >> 16);
95 }
96
97 self.checksum = !(sum as u16);
98 }
99}
100
/// A UDP socket handle.
///
/// Received datagrams are stored in the module-level `UDP_SOCKETS` table under
/// `socket_id`. The handle registers itself in `bind` and unregisters on drop
/// when bound.
///
/// NOTE(review): this type derives `Clone` and also implements `Drop`, and
/// clones share the same `socket_id`. Dropping one bound clone unregisters the
/// queue for all of them; a later `recv_from` on a surviving clone then fails
/// with `socket_id`/`not_found`. Consider removing the `Clone` derive.
#[derive(Debug, Clone)]
pub struct UdpSocket {
    /// Local address; `new` sets the unspecified IPv4 address with port 0, and `bind` overwrites it.
    pub local: SocketAddr,
    /// Default destination for `send`, set by `connect`.
    pub remote: Option<SocketAddr>,
    /// Set by a successful `bind`; controls whether `Drop` unregisters the socket.
    pub bound: bool,
    /// Key into the global receive-queue table, allocated from `NEXT_UDP_ID`.
    socket_id: usize,
}
109
110impl UdpSocket {
111 pub fn new() -> Self {
112 Self {
113 local: SocketAddr::v4(super::Ipv4Address::UNSPECIFIED, 0),
114 remote: None,
115 bound: false,
116 socket_id: NEXT_UDP_ID.fetch_add(1, Ordering::Relaxed),
117 }
118 }
119}
120
121impl Default for UdpSocket {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
127impl Drop for UdpSocket {
128 fn drop(&mut self) {
129 if self.bound {
130 unregister_socket(self.socket_id);
131 }
132 }
133}
134
135impl UdpSocket {
136 pub fn bind(&mut self, addr: SocketAddr) -> Result<(), KernelError> {
138 if self.bound {
139 return Err(KernelError::InvalidState {
140 expected: "unbound",
141 actual: "bound",
142 });
143 }
144
145 self.local = addr;
146 self.bound = true;
147 register_socket(self.socket_id, addr);
148 Ok(())
149 }
150
151 pub fn connect(&mut self, addr: SocketAddr) -> Result<(), KernelError> {
153 if !self.bound {
154 return Err(KernelError::InvalidState {
155 expected: "bound",
156 actual: "unbound",
157 });
158 }
159
160 self.remote = Some(addr);
161 Ok(())
162 }
163
164 pub fn send_to(&self, data: &[u8], dest: SocketAddr) -> Result<usize, KernelError> {
166 if !self.bound {
167 return Err(KernelError::InvalidState {
168 expected: "bound",
169 actual: "unbound",
170 });
171 }
172
173 let src_port = self.local.port();
175 let dst_port = dest.port();
176 let mut header = UdpHeader::new(src_port, dst_port, data.len());
177
178 header.calculate_checksum(self.local.ip(), dest.ip(), data);
180
181 super::ip::send(dest.ip(), super::ip::IpProtocol::Udp, data)?;
183
184 Ok(data.len())
185 }
186
187 pub fn send(&self, data: &[u8]) -> Result<usize, KernelError> {
189 if let Some(remote) = self.remote {
190 self.send_to(data, remote)
191 } else {
192 Err(KernelError::InvalidState {
193 expected: "connected",
194 actual: "not_connected",
195 })
196 }
197 }
198
199 pub fn recv_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), KernelError> {
201 if !self.bound {
202 return Err(KernelError::InvalidState {
203 expected: "bound",
204 actual: "unbound",
205 });
206 }
207
208 receive_from(self.socket_id, buffer)
209 }
210
211 pub fn recv(&self, buffer: &mut [u8]) -> Result<usize, KernelError> {
213 let (len, _) = self.recv_from(buffer)?;
214 Ok(len)
215 }
216}
217
218pub fn init() -> Result<(), KernelError> {
220 println!("[UDP] Initializing UDP protocol...");
221 println!("[UDP] UDP initialized");
222 Ok(())
223}
224
/// One received payload waiting in a socket's queue.
struct ReceivedDatagram {
    /// UDP payload bytes (the 8-byte header has already been stripped).
    data: Vec<u8>,
    /// Sender's IP address and source port.
    from: SocketAddr,
}
234
/// Per-socket receive state held in `UDP_SOCKETS`.
struct UdpSocketBuffer {
    /// Address the socket was bound with; only its port is compared against
    /// incoming datagrams.
    local_addr: SocketAddr,
    /// Pending datagrams, appended with `push` by `process_packet`.
    recv_queue: Vec<ReceivedDatagram>,
    /// Queue limit; datagrams arriving while the queue holds this many entries
    /// are not queued.
    max_queue_size: usize,
}
241
/// Receive queues of all bound sockets, keyed by socket id and guarded by a spin lock.
static UDP_SOCKETS: Mutex<BTreeMap<usize, UdpSocketBuffer>> = Mutex::new(BTreeMap::new());
244
245pub fn register_socket(socket_id: usize, local_addr: SocketAddr) {
247 let mut sockets = UDP_SOCKETS.lock();
248 sockets.insert(
249 socket_id,
250 UdpSocketBuffer {
251 local_addr,
252 recv_queue: Vec::new(),
253 max_queue_size: 64,
254 },
255 );
256}
257
258pub fn unregister_socket(socket_id: usize) {
260 let mut sockets = UDP_SOCKETS.lock();
261 sockets.remove(&socket_id);
262}
263
264pub fn receive_from(
266 socket_id: usize,
267 buffer: &mut [u8],
268) -> Result<(usize, SocketAddr), KernelError> {
269 let mut sockets = UDP_SOCKETS.lock();
270
271 if let Some(sock_buf) = sockets.get_mut(&socket_id) {
272 if let Some(datagram) = sock_buf.recv_queue.pop() {
273 let copy_len = buffer.len().min(datagram.data.len());
274 buffer[..copy_len].copy_from_slice(&datagram.data[..copy_len]);
275 return Ok((copy_len, datagram.from));
276 }
277 return Err(KernelError::WouldBlock);
278 }
279
280 Err(KernelError::InvalidArgument {
281 name: "socket_id",
282 value: "not_found",
283 })
284}
285
286pub fn process_packet(
288 src_addr: IpAddress,
289 dst_addr: IpAddress,
290 data: &[u8],
291) -> Result<(), KernelError> {
292 if data.len() < UdpHeader::SIZE {
293 return Err(KernelError::InvalidArgument {
294 name: "udp_packet",
295 value: "too_short",
296 });
297 }
298
299 let header = UdpHeader::from_bytes(data)?;
301
302 if data.len() < header.length as usize {
304 return Err(KernelError::InvalidArgument {
305 name: "udp_length",
306 value: "mismatch",
307 });
308 }
309
310 let payload = &data[UdpHeader::SIZE..header.length as usize];
312 let src = SocketAddr::new(src_addr, header.source_port);
313 let _dst = SocketAddr::new(dst_addr, header.dest_port);
314
315 let mut sockets = UDP_SOCKETS.lock();
317 for (_socket_id, sock_buf) in sockets.iter_mut() {
318 if sock_buf.local_addr.port() == header.dest_port || sock_buf.local_addr.port() == 0 {
319 if sock_buf.recv_queue.len() < sock_buf.max_queue_size {
321 sock_buf.recv_queue.push(ReceivedDatagram {
322 data: payload.to_vec(),
323 from: src,
324 });
325
326 #[cfg(feature = "net_debug")]
327 println!(
328 "[UDP] Queued {} bytes from {:?} for socket {} (port {})",
329 payload.len(),
330 src,
331 _socket_id,
332 _dst.port()
333 );
334
335 return Ok(());
336 } else {
337 #[cfg(feature = "net_debug")]
338 println!("[UDP] Socket {} queue full, dropping packet", _socket_id);
339 return Err(KernelError::ResourceExhausted {
340 resource: "udp_queue",
341 });
342 }
343 }
344 }
345
346 #[cfg(feature = "net_debug")]
347 println!(
348 "[UDP] No socket for port {}, dropping packet",
349 header.dest_port
350 );
351
352 Ok(())
353}
354
355pub fn send_packet(src: SocketAddr, dst: SocketAddr, data: &[u8]) -> Result<usize, KernelError> {
357 let mut header = UdpHeader::new(src.port(), dst.port(), data.len());
359
360 header.calculate_checksum(src.ip(), dst.ip(), data);
362
363 let header_bytes = header.to_bytes();
365 let mut packet = Vec::with_capacity(UdpHeader::SIZE + data.len());
366 packet.extend_from_slice(&header_bytes);
367 packet.extend_from_slice(data);
368
369 super::ip::send(dst.ip(), super::ip::IpProtocol::Udp, &packet)?;
371
372 Ok(data.len())
373}
374
375pub fn get_stats() -> UdpStats {
377 let sockets = UDP_SOCKETS.lock();
378 let mut total_queued = 0;
379 for sock in sockets.values() {
380 total_queued += sock.recv_queue.len();
381 }
382
383 UdpStats {
384 active_sockets: sockets.len(),
385 datagrams_queued: total_queued,
386 datagrams_sent: 0, datagrams_recv: 0, datagrams_dropped: 0, }
390}
391
/// Counters returned by `get_stats`.
#[derive(Debug, Clone, Copy, Default)]
pub struct UdpStats {
    /// Number of entries in the socket table (i.e. bound sockets).
    pub active_sockets: usize,
    /// Total datagrams currently waiting across all receive queues.
    pub datagrams_queued: usize,
    /// Presumably datagrams sent; NOTE(review): not tracked — `get_stats` reports 0.
    pub datagrams_sent: u64,
    /// Presumably datagrams received; NOTE(review): not tracked — `get_stats` reports 0.
    pub datagrams_recv: u64,
    /// Presumably datagrams dropped; NOTE(review): not tracked — `get_stats` reports 0.
    pub datagrams_dropped: u64,
}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::Ipv4Address;

    /// `new` records both ports and adds the 8-byte header to the payload length.
    #[test]
    fn test_udp_header() {
        let UdpHeader {
            source_port,
            dest_port,
            length,
            ..
        } = UdpHeader::new(8080, 80, 100);
        assert_eq!((source_port, dest_port, length), (8080, 80, 108));
    }

    /// Serializing and re-parsing a header preserves ports and length.
    #[test]
    fn test_udp_header_roundtrip() {
        let wire = UdpHeader::new(1234, 5678, 50).to_bytes();
        let parsed = UdpHeader::from_bytes(&wire).expect("an 8-byte buffer always parses");

        assert_eq!(parsed.source_port, 1234);
        assert_eq!(parsed.dest_port, 5678);
        assert_eq!(parsed.length, 58);
    }

    /// A fresh socket starts unbound and becomes bound after `bind`.
    #[test]
    fn test_udp_socket() {
        let mut sock = UdpSocket::new();
        assert!(!sock.bound);

        sock.bind(SocketAddr::v4(Ipv4Address::LOCALHOST, 8080)).unwrap();
        assert!(sock.bound);
    }
}