⚠️ VeridianOS Kernel Documentation — This is low-level kernel code. All functions are unsafe unless explicitly marked otherwise. The kernel is built as `no_std`.

veridian_kernel/net/
udp.rs

1//! UDP protocol implementation
2
3use alloc::{collections::BTreeMap, vec::Vec};
4use core::sync::atomic::{AtomicUsize, Ordering};
5
6use spin::Mutex;
7
8use super::{IpAddress, SocketAddr};
9use crate::error::KernelError;
10
11/// Next UDP socket ID counter
12static NEXT_UDP_ID: AtomicUsize = AtomicUsize::new(1);
13
14/// UDP header
15#[derive(Debug, Clone)]
16pub struct UdpHeader {
17    pub source_port: u16,
18    pub dest_port: u16,
19    pub length: u16,
20    pub checksum: u16,
21}
22
23impl UdpHeader {
24    pub const SIZE: usize = 8;
25
26    pub fn new(src_port: u16, dst_port: u16, data_len: usize) -> Self {
27        Self {
28            source_port: src_port,
29            dest_port: dst_port,
30            length: (Self::SIZE + data_len) as u16,
31            checksum: 0,
32        }
33    }
34
35    pub fn to_bytes(&self) -> [u8; 8] {
36        let mut bytes = [0u8; 8];
37        bytes[0..2].copy_from_slice(&self.source_port.to_be_bytes());
38        bytes[2..4].copy_from_slice(&self.dest_port.to_be_bytes());
39        bytes[4..6].copy_from_slice(&self.length.to_be_bytes());
40        bytes[6..8].copy_from_slice(&self.checksum.to_be_bytes());
41        bytes
42    }
43
44    pub fn from_bytes(bytes: &[u8]) -> Result<Self, KernelError> {
45        if bytes.len() < Self::SIZE {
46            return Err(KernelError::InvalidArgument {
47                name: "udp_header",
48                value: "too_short",
49            });
50        }
51
52        Ok(Self {
53            source_port: u16::from_be_bytes([bytes[0], bytes[1]]),
54            dest_port: u16::from_be_bytes([bytes[2], bytes[3]]),
55            length: u16::from_be_bytes([bytes[4], bytes[5]]),
56            checksum: u16::from_be_bytes([bytes[6], bytes[7]]),
57        })
58    }
59
60    /// Calculate UDP checksum
61    pub fn calculate_checksum(&mut self, src: IpAddress, dst: IpAddress, data: &[u8]) {
62        self.checksum = 0;
63
64        // UDP checksum includes pseudo-header
65        let mut sum: u32 = 0;
66
67        // Add pseudo-header (source IP, dest IP, protocol, length)
68        if let (IpAddress::V4(src_v4), IpAddress::V4(dst_v4)) = (src, dst) {
69            sum += u16::from_be_bytes([src_v4.0[0], src_v4.0[1]]) as u32;
70            sum += u16::from_be_bytes([src_v4.0[2], src_v4.0[3]]) as u32;
71            sum += u16::from_be_bytes([dst_v4.0[0], dst_v4.0[1]]) as u32;
72            sum += u16::from_be_bytes([dst_v4.0[2], dst_v4.0[3]]) as u32;
73            sum += 17u32; // Protocol (UDP)
74            sum += self.length as u32;
75        }
76
77        // Add UDP header
78        let header_bytes = self.to_bytes();
79        for i in 0..4 {
80            sum += u16::from_be_bytes([header_bytes[i * 2], header_bytes[i * 2 + 1]]) as u32;
81        }
82
83        // Add data
84        for chunk in data.chunks(2) {
85            if chunk.len() == 2 {
86                sum += u16::from_be_bytes([chunk[0], chunk[1]]) as u32;
87            } else {
88                sum += (chunk[0] as u32) << 8;
89            }
90        }
91
92        // Fold 32-bit sum to 16 bits
93        while sum >> 16 != 0 {
94            sum = (sum & 0xFFFF) + (sum >> 16);
95        }
96
97        self.checksum = !(sum as u16);
98    }
99}
100
101/// UDP socket
102#[derive(Debug, Clone)]
103pub struct UdpSocket {
104    pub local: SocketAddr,
105    pub remote: Option<SocketAddr>,
106    pub bound: bool,
107    socket_id: usize,
108}
109
110impl UdpSocket {
111    pub fn new() -> Self {
112        Self {
113            local: SocketAddr::v4(super::Ipv4Address::UNSPECIFIED, 0),
114            remote: None,
115            bound: false,
116            socket_id: NEXT_UDP_ID.fetch_add(1, Ordering::Relaxed),
117        }
118    }
119}
120
121impl Default for UdpSocket {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
127impl Drop for UdpSocket {
128    fn drop(&mut self) {
129        if self.bound {
130            unregister_socket(self.socket_id);
131        }
132    }
133}
134
135impl UdpSocket {
136    /// Bind to local address
137    pub fn bind(&mut self, addr: SocketAddr) -> Result<(), KernelError> {
138        if self.bound {
139            return Err(KernelError::InvalidState {
140                expected: "unbound",
141                actual: "bound",
142            });
143        }
144
145        self.local = addr;
146        self.bound = true;
147        register_socket(self.socket_id, addr);
148        Ok(())
149    }
150
151    /// Connect to remote address (optional for UDP)
152    pub fn connect(&mut self, addr: SocketAddr) -> Result<(), KernelError> {
153        if !self.bound {
154            return Err(KernelError::InvalidState {
155                expected: "bound",
156                actual: "unbound",
157            });
158        }
159
160        self.remote = Some(addr);
161        Ok(())
162    }
163
164    /// Send data to specific address
165    pub fn send_to(&self, data: &[u8], dest: SocketAddr) -> Result<usize, KernelError> {
166        if !self.bound {
167            return Err(KernelError::InvalidState {
168                expected: "bound",
169                actual: "unbound",
170            });
171        }
172
173        // Create UDP header
174        let src_port = self.local.port();
175        let dst_port = dest.port();
176        let mut header = UdpHeader::new(src_port, dst_port, data.len());
177
178        // Calculate checksum
179        header.calculate_checksum(self.local.ip(), dest.ip(), data);
180
181        // Send via IP layer
182        super::ip::send(dest.ip(), super::ip::IpProtocol::Udp, data)?;
183
184        Ok(data.len())
185    }
186
187    /// Send data to connected address
188    pub fn send(&self, data: &[u8]) -> Result<usize, KernelError> {
189        if let Some(remote) = self.remote {
190            self.send_to(data, remote)
191        } else {
192            Err(KernelError::InvalidState {
193                expected: "connected",
194                actual: "not_connected",
195            })
196        }
197    }
198
199    /// Receive data
200    pub fn recv_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), KernelError> {
201        if !self.bound {
202            return Err(KernelError::InvalidState {
203                expected: "bound",
204                actual: "unbound",
205            });
206        }
207
208        receive_from(self.socket_id, buffer)
209    }
210
211    /// Receive data (from connected address)
212    pub fn recv(&self, buffer: &mut [u8]) -> Result<usize, KernelError> {
213        let (len, _) = self.recv_from(buffer)?;
214        Ok(len)
215    }
216}
217
218/// Initialize UDP
219pub fn init() -> Result<(), KernelError> {
220    println!("[UDP] Initializing UDP protocol...");
221    println!("[UDP] UDP initialized");
222    Ok(())
223}
224
225// ============================================================================
226// Socket Layer Interface
227// ============================================================================
228
229/// Received UDP datagram with source address
230struct ReceivedDatagram {
231    data: Vec<u8>,
232    from: SocketAddr,
233}
234
235/// UDP receive buffer per socket
236struct UdpSocketBuffer {
237    local_addr: SocketAddr,
238    recv_queue: Vec<ReceivedDatagram>,
239    max_queue_size: usize,
240}
241
242/// Global UDP socket buffers
243static UDP_SOCKETS: Mutex<BTreeMap<usize, UdpSocketBuffer>> = Mutex::new(BTreeMap::new());
244
245/// Register a UDP socket for receiving
246pub fn register_socket(socket_id: usize, local_addr: SocketAddr) {
247    let mut sockets = UDP_SOCKETS.lock();
248    sockets.insert(
249        socket_id,
250        UdpSocketBuffer {
251            local_addr,
252            recv_queue: Vec::new(),
253            max_queue_size: 64,
254        },
255    );
256}
257
258/// Unregister a UDP socket
259pub fn unregister_socket(socket_id: usize) {
260    let mut sockets = UDP_SOCKETS.lock();
261    sockets.remove(&socket_id);
262}
263
264/// Receive data from a UDP socket (called by socket layer)
265pub fn receive_from(
266    socket_id: usize,
267    buffer: &mut [u8],
268) -> Result<(usize, SocketAddr), KernelError> {
269    let mut sockets = UDP_SOCKETS.lock();
270
271    if let Some(sock_buf) = sockets.get_mut(&socket_id) {
272        if let Some(datagram) = sock_buf.recv_queue.pop() {
273            let copy_len = buffer.len().min(datagram.data.len());
274            buffer[..copy_len].copy_from_slice(&datagram.data[..copy_len]);
275            return Ok((copy_len, datagram.from));
276        }
277        return Err(KernelError::WouldBlock);
278    }
279
280    Err(KernelError::InvalidArgument {
281        name: "socket_id",
282        value: "not_found",
283    })
284}
285
286/// Process incoming UDP packet (called by IP layer)
287pub fn process_packet(
288    src_addr: IpAddress,
289    dst_addr: IpAddress,
290    data: &[u8],
291) -> Result<(), KernelError> {
292    if data.len() < UdpHeader::SIZE {
293        return Err(KernelError::InvalidArgument {
294            name: "udp_packet",
295            value: "too_short",
296        });
297    }
298
299    // Parse UDP header
300    let header = UdpHeader::from_bytes(data)?;
301
302    // Validate length
303    if data.len() < header.length as usize {
304        return Err(KernelError::InvalidArgument {
305            name: "udp_length",
306            value: "mismatch",
307        });
308    }
309
310    // Extract payload
311    let payload = &data[UdpHeader::SIZE..header.length as usize];
312    let src = SocketAddr::new(src_addr, header.source_port);
313    let _dst = SocketAddr::new(dst_addr, header.dest_port);
314
315    // Find matching socket by destination port
316    let mut sockets = UDP_SOCKETS.lock();
317    for (_socket_id, sock_buf) in sockets.iter_mut() {
318        if sock_buf.local_addr.port() == header.dest_port || sock_buf.local_addr.port() == 0 {
319            // Check queue size
320            if sock_buf.recv_queue.len() < sock_buf.max_queue_size {
321                sock_buf.recv_queue.push(ReceivedDatagram {
322                    data: payload.to_vec(),
323                    from: src,
324                });
325
326                #[cfg(feature = "net_debug")]
327                println!(
328                    "[UDP] Queued {} bytes from {:?} for socket {} (port {})",
329                    payload.len(),
330                    src,
331                    _socket_id,
332                    _dst.port()
333                );
334
335                return Ok(());
336            } else {
337                #[cfg(feature = "net_debug")]
338                println!("[UDP] Socket {} queue full, dropping packet", _socket_id);
339                return Err(KernelError::ResourceExhausted {
340                    resource: "udp_queue",
341                });
342            }
343        }
344    }
345
346    #[cfg(feature = "net_debug")]
347    println!(
348        "[UDP] No socket for port {}, dropping packet",
349        header.dest_port
350    );
351
352    Ok(())
353}
354
355/// Send UDP packet (internal implementation)
356pub fn send_packet(src: SocketAddr, dst: SocketAddr, data: &[u8]) -> Result<usize, KernelError> {
357    // Create UDP header
358    let mut header = UdpHeader::new(src.port(), dst.port(), data.len());
359
360    // Calculate checksum
361    header.calculate_checksum(src.ip(), dst.ip(), data);
362
363    // Build packet: header + data
364    let header_bytes = header.to_bytes();
365    let mut packet = Vec::with_capacity(UdpHeader::SIZE + data.len());
366    packet.extend_from_slice(&header_bytes);
367    packet.extend_from_slice(data);
368
369    // Send via IP layer
370    super::ip::send(dst.ip(), super::ip::IpProtocol::Udp, &packet)?;
371
372    Ok(data.len())
373}
374
375/// Get UDP statistics
376pub fn get_stats() -> UdpStats {
377    let sockets = UDP_SOCKETS.lock();
378    let mut total_queued = 0;
379    for sock in sockets.values() {
380        total_queued += sock.recv_queue.len();
381    }
382
383    UdpStats {
384        active_sockets: sockets.len(),
385        datagrams_queued: total_queued,
386        datagrams_sent: 0,    // Would track in real implementation
387        datagrams_recv: 0,    // Would track in real implementation
388        datagrams_dropped: 0, // Would track in real implementation
389    }
390}
391
392/// UDP statistics
393#[derive(Debug, Clone, Copy, Default)]
394pub struct UdpStats {
395    pub active_sockets: usize,
396    pub datagrams_queued: usize,
397    pub datagrams_sent: u64,
398    pub datagrams_recv: u64,
399    pub datagrams_dropped: u64,
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::Ipv4Address;

    #[test]
    fn test_udp_header() {
        let header = UdpHeader::new(8080, 80, 100);
        assert_eq!(header.source_port, 8080);
        assert_eq!(header.dest_port, 80);
        assert_eq!(header.length, 108); // 8 + 100
    }

    #[test]
    fn test_udp_header_roundtrip() {
        let header = UdpHeader::new(1234, 5678, 50);
        let bytes = header.to_bytes();
        let parsed = UdpHeader::from_bytes(&bytes).unwrap();

        assert_eq!(parsed.source_port, 1234);
        assert_eq!(parsed.dest_port, 5678);
        assert_eq!(parsed.length, 58);
    }

    #[test]
    fn test_udp_header_to_bytes_is_big_endian() {
        let mut header = UdpHeader::new(0x1234, 0xABCD, 4);
        header.checksum = 0xBEEF;
        // Length is 8 + 4 = 12 = 0x000C.
        assert_eq!(
            header.to_bytes(),
            [0x12, 0x34, 0xAB, 0xCD, 0x00, 0x0C, 0xBE, 0xEF]
        );
    }

    #[test]
    fn test_udp_header_from_bytes_too_short() {
        assert!(UdpHeader::from_bytes(&[0u8; UdpHeader::SIZE - 1]).is_err());
        assert!(UdpHeader::from_bytes(&[]).is_err());
    }

    #[test]
    fn test_udp_socket() {
        let mut socket = UdpSocket::new();
        let addr = SocketAddr::v4(Ipv4Address::LOCALHOST, 8080);

        assert!(!socket.bound);
        socket.bind(addr).unwrap();
        assert!(socket.bound);
    }

    #[test]
    fn test_udp_socket_double_bind_fails() {
        let mut socket = UdpSocket::new();
        let addr = SocketAddr::v4(Ipv4Address::LOCALHOST, 8081);

        socket.bind(addr).unwrap();
        assert!(socket.bind(addr).is_err());
    }

    #[test]
    fn test_unbound_socket_rejects_io() {
        let mut socket = UdpSocket::new();
        let peer = SocketAddr::v4(Ipv4Address::LOCALHOST, 9000);
        let mut buffer = [0u8; 16];

        assert!(socket.connect(peer).is_err());
        assert!(socket.send_to(b"x", peer).is_err());
        assert!(socket.recv_from(&mut buffer).is_err());
    }
}