1#![allow(clippy::manual_clamp)]
8
9use alloc::{collections::VecDeque, vec::Vec};
10use core::sync::atomic::{AtomicUsize, Ordering};
11
12use spin::Mutex;
13
14use super::{IpAddress, SocketAddr};
15use crate::error::KernelError;
16
/// Address family requested when a socket is created.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketDomain {
    /// IPv4 family (`AF_INET`-style).
    Inet,
    /// IPv6 family (`AF_INET6`-style).
    Inet6,
    /// Local (Unix-domain) family; `Socket::new` currently rejects it.
    Unix,
}
27
/// Communication semantics of a socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketType {
    /// Connection-oriented byte stream; the only type that supports `listen`/`accept`.
    Stream,
    /// Datagrams; the only type accepted by `Socket::send_to`.
    Dgram,
    /// Raw packets; sending is not implemented yet.
    Raw,
}
38
/// Transport protocol requested for a socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketProtocol {
    /// Let the socket type decide (`Stream` and `Dgram` both accept it).
    Default,
    /// TCP; valid with `Stream` (and with `Raw`, which accepts any protocol).
    Tcp,
    /// UDP; valid with `Dgram` (and with `Raw`).
    Udp,
    /// ICMP; only accepted together with `Raw`.
    Icmp,
}
51
/// Lifecycle state of a socket.
///
/// `Unbound` -(bind)-> `Bound` -(listen)-> `Listening`; `Unbound`/`Bound`
/// -(connect)-> `Connected`; any state -(close)-> `Closed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketState {
    /// Freshly created, no local address yet.
    Unbound,
    /// Has a local address.
    Bound,
    /// Accepting incoming stream connections.
    Listening,
    /// Has a remote peer.
    Connected,
    /// Closed; buffers have been released.
    Closed,
}
61
/// Per-socket tunables, changed through `Socket::set_option`.
#[derive(Debug, Clone, Copy)]
pub struct SocketOptions {
    /// Skip the address-in-use check in `Socket::bind`.
    pub reuse_addr: bool,
    /// Stored only; not consulted by this module.
    pub reuse_port: bool,
    /// Stored only; not consulted by this module.
    pub broadcast: bool,
    /// Stored only; not consulted by this module.
    pub keepalive: bool,
    /// Bytes reserved for the receive buffer on connect/accept.
    pub recv_buffer_size: usize,
    /// Bytes reserved for the send buffer, and the per-call stream send limit.
    pub send_buffer_size: usize,
    /// Receive timeout in milliseconds (`None` = no timeout); stored only.
    pub recv_timeout_ms: Option<u64>,
    /// Send timeout in milliseconds (`None` = no timeout); stored only.
    pub send_timeout_ms: Option<u64>,
}

impl Default for SocketOptions {
    /// All flags off, 64 KiB buffers in each direction, no timeouts.
    fn default() -> Self {
        const BUFFER_BYTES: usize = 64 * 1024;
        Self {
            recv_buffer_size: BUFFER_BYTES,
            send_buffer_size: BUFFER_BYTES,
            recv_timeout_ms: None,
            send_timeout_ms: None,
            reuse_addr: false,
            reuse_port: false,
            broadcast: false,
            keepalive: false,
        }
    }
}
89
90#[derive(Debug, Clone)]
92pub struct PendingConnection {
93 pub remote_addr: SocketAddr,
95 pub seq_num: u32,
97}
98
99#[derive(Debug, Clone)]
101pub struct Socket {
102 pub id: usize,
103 pub domain: SocketDomain,
104 pub socket_type: SocketType,
105 pub protocol: SocketProtocol,
106 pub state: SocketState,
107 pub local_addr: Option<SocketAddr>,
108 pub remote_addr: Option<SocketAddr>,
109 pub options: SocketOptions,
110 recv_buffer: Vec<u8>,
112 send_buffer: Vec<u8>,
114 backlog: usize,
116}
117
118impl Socket {
119 pub fn new(
121 domain: SocketDomain,
122 socket_type: SocketType,
123 protocol: SocketProtocol,
124 ) -> Result<Self, KernelError> {
125 match (domain, socket_type, protocol) {
127 (SocketDomain::Inet, SocketType::Stream, SocketProtocol::Tcp)
128 | (SocketDomain::Inet, SocketType::Stream, SocketProtocol::Default)
129 | (SocketDomain::Inet, SocketType::Dgram, SocketProtocol::Udp)
130 | (SocketDomain::Inet, SocketType::Dgram, SocketProtocol::Default)
131 | (SocketDomain::Inet, SocketType::Raw, _)
132 | (SocketDomain::Inet6, SocketType::Stream, SocketProtocol::Tcp)
133 | (SocketDomain::Inet6, SocketType::Stream, SocketProtocol::Default)
134 | (SocketDomain::Inet6, SocketType::Dgram, SocketProtocol::Udp)
135 | (SocketDomain::Inet6, SocketType::Dgram, SocketProtocol::Default)
136 | (SocketDomain::Inet6, SocketType::Raw, _) => {}
137 _ => {
138 return Err(KernelError::InvalidArgument {
139 name: "socket_combination",
140 value: "unsupported",
141 })
142 }
143 }
144
145 Ok(Self {
146 id: 0, domain,
148 socket_type,
149 protocol,
150 state: SocketState::Unbound,
151 local_addr: None,
152 remote_addr: None,
153 options: SocketOptions::default(),
154 recv_buffer: Vec::new(),
155 send_buffer: Vec::new(),
156 backlog: 0,
157 })
158 }
159
160 pub fn bind(&mut self, addr: SocketAddr) -> Result<(), KernelError> {
162 if self.state != SocketState::Unbound {
163 return Err(KernelError::InvalidState {
164 expected: "unbound",
165 actual: "already_bound",
166 });
167 }
168
169 if !self.options.reuse_addr && is_address_in_use(&addr) {
171 return Err(KernelError::ResourceExhausted {
172 resource: "socket_address",
173 });
174 }
175
176 self.local_addr = Some(addr);
177 self.state = SocketState::Bound;
178 Ok(())
179 }
180
181 pub fn listen(&mut self, backlog: usize) -> Result<(), KernelError> {
183 if self.socket_type != SocketType::Stream {
184 return Err(KernelError::InvalidArgument {
185 name: "socket_type",
186 value: "not_stream",
187 });
188 }
189
190 if self.state != SocketState::Bound {
191 return Err(KernelError::InvalidState {
192 expected: "bound",
193 actual: "not_bound",
194 });
195 }
196
197 self.backlog = backlog.max(1).min(128); if let Some(addr) = self.local_addr {
202 register_listening_socket(self.id, addr, self.backlog);
203 }
204
205 self.state = SocketState::Listening;
206 Ok(())
207 }
208
209 pub fn connect(&mut self, addr: SocketAddr) -> Result<(), KernelError> {
211 match self.state {
212 SocketState::Unbound | SocketState::Bound => {}
213 _ => {
214 return Err(KernelError::InvalidState {
215 expected: "unbound_or_bound",
216 actual: "other",
217 })
218 }
219 }
220
221 if self.state == SocketState::Unbound {
223 let local_addr = match addr.ip() {
224 IpAddress::V4(_) => SocketAddr::v4(super::Ipv4Address::UNSPECIFIED, 0),
225 IpAddress::V6(_) => SocketAddr::v6(super::Ipv6Address::UNSPECIFIED, 0),
226 };
227 self.bind(local_addr)?;
228 }
229
230 self.remote_addr = Some(addr);
231
232 match self.socket_type {
234 SocketType::Stream => {
235 self.state = SocketState::Connected;
239
240 self.recv_buffer.reserve(self.options.recv_buffer_size);
242 self.send_buffer.reserve(self.options.send_buffer_size);
243 }
244 SocketType::Dgram => {
245 self.state = SocketState::Connected;
247 }
248 SocketType::Raw => {
249 self.state = SocketState::Connected;
250 }
251 }
252
253 Ok(())
254 }
255
256 pub fn accept(&self) -> Result<(Socket, SocketAddr), KernelError> {
258 if self.socket_type != SocketType::Stream {
259 return Err(KernelError::InvalidArgument {
260 name: "socket_type",
261 value: "not_stream",
262 });
263 }
264
265 if self.state != SocketState::Listening {
266 return Err(KernelError::InvalidState {
267 expected: "listening",
268 actual: "not_listening",
269 });
270 }
271
272 if let Some(pending) = dequeue_pending_connection(self.id) {
274 let mut new_socket = Socket::new(self.domain, self.socket_type, self.protocol)?;
276 new_socket.local_addr = self.local_addr;
277 new_socket.remote_addr = Some(pending.remote_addr);
278 new_socket.state = SocketState::Connected;
279 new_socket
280 .recv_buffer
281 .reserve(self.options.recv_buffer_size);
282 new_socket
283 .send_buffer
284 .reserve(self.options.send_buffer_size);
285 new_socket.options = self.options;
286
287 Ok((new_socket, pending.remote_addr))
288 } else {
289 Err(KernelError::WouldBlock)
291 }
292 }
293
294 pub fn send(&mut self, data: &[u8], _flags: u32) -> Result<usize, KernelError> {
296 if self.state != SocketState::Connected {
297 return Err(KernelError::InvalidState {
298 expected: "connected",
299 actual: "not_connected",
300 });
301 }
302
303 let remote = self.remote_addr.ok_or(KernelError::InvalidState {
304 expected: "remote_addr_set",
305 actual: "no_remote_addr",
306 })?;
307
308 match self.socket_type {
309 SocketType::Stream => {
310 let send_len = data
312 .len()
313 .min(self.options.send_buffer_size - self.send_buffer.len());
314 if send_len == 0 && !data.is_empty() {
315 return Err(KernelError::WouldBlock);
316 }
317 self.send_buffer.extend_from_slice(&data[..send_len]);
318
319 super::tcp::transmit_data(self.id, &self.send_buffer, remote);
321 let sent = self.send_buffer.len();
322 self.send_buffer.clear();
323
324 Ok(sent)
325 }
326 SocketType::Dgram => {
327 super::udp::UdpSocket::new().send_to(data, remote)
329 }
330 SocketType::Raw => Err(KernelError::NotImplemented {
331 feature: "raw_socket_send",
332 }),
333 }
334 }
335
336 pub fn send_to(
338 &self,
339 data: &[u8],
340 dest: SocketAddr,
341 _flags: u32,
342 ) -> Result<usize, KernelError> {
343 if self.socket_type != SocketType::Dgram {
344 return Err(KernelError::InvalidArgument {
345 name: "socket_type",
346 value: "not_dgram",
347 });
348 }
349
350 super::udp::UdpSocket::new().send_to(data, dest)
352 }
353
354 pub fn recv(&mut self, buffer: &mut [u8], _flags: u32) -> Result<usize, KernelError> {
356 if self.state != SocketState::Connected {
357 return Err(KernelError::InvalidState {
358 expected: "connected",
359 actual: "not_connected",
360 });
361 }
362
363 if self.recv_buffer.is_empty() {
365 let received = super::tcp::receive_data(self.id, &mut self.recv_buffer);
367 if received == 0 {
368 return Err(KernelError::WouldBlock);
369 }
370 }
371
372 let copy_len = buffer.len().min(self.recv_buffer.len());
374 buffer[..copy_len].copy_from_slice(&self.recv_buffer[..copy_len]);
375 self.recv_buffer.drain(..copy_len);
376
377 Ok(copy_len)
378 }
379
380 pub fn recv_from(
382 &mut self,
383 buffer: &mut [u8],
384 _flags: u32,
385 ) -> Result<(usize, SocketAddr), KernelError> {
386 if self.state == SocketState::Unbound {
387 return Err(KernelError::InvalidState {
388 expected: "bound",
389 actual: "unbound",
390 });
391 }
392
393 if let Some(remote) = self.remote_addr {
395 let len = self.recv(buffer, _flags)?;
396 return Ok((len, remote));
397 }
398
399 if self.socket_type == SocketType::Dgram {
401 let (len, from_addr) = super::udp::receive_from(self.id, buffer)?;
402 return Ok((len, from_addr));
403 }
404
405 Err(KernelError::WouldBlock)
406 }
407
408 pub fn close(&mut self) -> Result<(), KernelError> {
410 match self.state {
412 SocketState::Connected => {
413 if self.socket_type == SocketType::Stream {
414 super::tcp::close_connection(self.id);
416 }
417 }
418 SocketState::Listening => {
419 unregister_listening_socket(self.id);
421 }
422 _ => {}
423 }
424
425 self.recv_buffer.clear();
427 self.send_buffer.clear();
428
429 self.state = SocketState::Closed;
430 Ok(())
431 }
432
433 pub fn set_option(&mut self, option: SocketOption) -> Result<(), KernelError> {
435 match option {
436 SocketOption::ReuseAddr(val) => self.options.reuse_addr = val,
437 SocketOption::ReusePort(val) => self.options.reuse_port = val,
438 SocketOption::Broadcast(val) => self.options.broadcast = val,
439 SocketOption::KeepAlive(val) => self.options.keepalive = val,
440 SocketOption::RecvBufferSize(val) => self.options.recv_buffer_size = val,
441 SocketOption::SendBufferSize(val) => self.options.send_buffer_size = val,
442 SocketOption::RecvTimeout(val) => self.options.recv_timeout_ms = val,
443 SocketOption::SendTimeout(val) => self.options.send_timeout_ms = val,
444 }
445 Ok(())
446 }
447}
448
/// One option update accepted by `Socket::set_option`; each variant maps to
/// the `SocketOptions` field of the same meaning.
#[derive(Debug, Clone)]
pub enum SocketOption {
    /// Sets `reuse_addr`.
    ReuseAddr(bool),
    /// Sets `reuse_port`.
    ReusePort(bool),
    /// Sets `broadcast`.
    Broadcast(bool),
    /// Sets `keepalive`.
    KeepAlive(bool),
    /// Sets `recv_buffer_size` (bytes).
    RecvBufferSize(usize),
    /// Sets `send_buffer_size` (bytes).
    SendBufferSize(usize),
    /// Sets `recv_timeout_ms`; `None` disables the timeout.
    RecvTimeout(Option<u64>),
    /// Sets `send_timeout_ms`; `None` disables the timeout.
    SendTimeout(Option<u64>),
}
461
462static SOCKET_TABLE: Mutex<Option<Vec<Socket>>> = Mutex::new(None);
464static NEXT_SOCKET_ID: AtomicUsize = AtomicUsize::new(1);
465
466pub fn init() -> Result<(), KernelError> {
468 println!("[SOCKET] Initializing socket subsystem...");
469
470 let mut table = SOCKET_TABLE.lock();
471 *table = Some(Vec::new());
472
473 println!("[SOCKET] Socket subsystem initialized");
474 Ok(())
475}
476
477pub fn create_socket(
479 domain: SocketDomain,
480 socket_type: SocketType,
481 protocol: SocketProtocol,
482) -> Result<usize, KernelError> {
483 let mut socket = Socket::new(domain, socket_type, protocol)?;
484
485 let id = NEXT_SOCKET_ID.fetch_add(1, Ordering::Relaxed);
486 socket.id = id;
487
488 let mut table = SOCKET_TABLE.lock();
489 if let Some(ref mut sockets) = *table {
490 sockets.push(socket);
491 Ok(id)
492 } else {
493 Err(KernelError::InvalidState {
494 expected: "initialized",
495 actual: "not_initialized",
496 })
497 }
498}
499
500pub fn with_socket<R, F: FnOnce(&Socket) -> R>(id: usize, f: F) -> Result<R, KernelError> {
502 let table = SOCKET_TABLE.lock();
503 if let Some(ref sockets) = *table {
504 sockets
505 .iter()
506 .find(|s| s.id == id)
507 .map(f)
508 .ok_or(KernelError::InvalidArgument {
509 name: "socket_id",
510 value: "not_found",
511 })
512 } else {
513 Err(KernelError::InvalidState {
514 expected: "initialized",
515 actual: "not_initialized",
516 })
517 }
518}
519
520pub fn with_socket_mut<R, F: FnOnce(&mut Socket) -> R>(id: usize, f: F) -> Result<R, KernelError> {
522 let mut table = SOCKET_TABLE.lock();
523 if let Some(ref mut sockets) = *table {
524 sockets
525 .iter_mut()
526 .find(|s| s.id == id)
527 .map(f)
528 .ok_or(KernelError::InvalidArgument {
529 name: "socket_id",
530 value: "not_found",
531 })
532 } else {
533 Err(KernelError::InvalidState {
534 expected: "initialized",
535 actual: "not_initialized",
536 })
537 }
538}
539
540struct ListeningSocketEntry {
542 socket_id: usize,
543 addr: SocketAddr,
544 backlog: usize,
545 pending_connections: VecDeque<PendingConnection>,
546}
547
548static LISTENING_SOCKETS: Mutex<Vec<ListeningSocketEntry>> = Mutex::new(Vec::new());
549static BOUND_ADDRESSES: Mutex<Vec<SocketAddr>> = Mutex::new(Vec::new());
550
551fn is_address_in_use(addr: &SocketAddr) -> bool {
553 let bound = BOUND_ADDRESSES.lock();
554 bound.iter().any(|a| a == addr)
555}
556
557fn register_listening_socket(socket_id: usize, addr: SocketAddr, backlog: usize) {
559 {
561 let mut bound = BOUND_ADDRESSES.lock();
562 if !bound.iter().any(|a| a == &addr) {
563 bound.push(addr);
564 }
565 }
566
567 let mut listeners = LISTENING_SOCKETS.lock();
569 listeners.push(ListeningSocketEntry {
570 socket_id,
571 addr,
572 backlog,
573 pending_connections: VecDeque::with_capacity(backlog),
574 });
575}
576
577fn unregister_listening_socket(socket_id: usize) {
579 let mut listeners = LISTENING_SOCKETS.lock();
580 if let Some(pos) = listeners.iter().position(|e| e.socket_id == socket_id) {
581 let entry = listeners.remove(pos);
582
583 let mut bound = BOUND_ADDRESSES.lock();
585 if let Some(pos) = bound.iter().position(|a| a == &entry.addr) {
586 bound.remove(pos);
587 }
588 }
589}
590
591fn dequeue_pending_connection(socket_id: usize) -> Option<PendingConnection> {
593 let mut listeners = LISTENING_SOCKETS.lock();
594 for entry in listeners.iter_mut() {
595 if entry.socket_id == socket_id {
596 return entry.pending_connections.pop_front();
597 }
598 }
599 None
600}
601
602pub fn queue_pending_connection(
604 addr: SocketAddr,
605 remote: SocketAddr,
606 seq_num: u32,
607) -> Result<(), KernelError> {
608 let mut listeners = LISTENING_SOCKETS.lock();
609 for entry in listeners.iter_mut() {
610 if entry.addr == addr {
611 if entry.pending_connections.len() < entry.backlog {
612 entry.pending_connections.push_back(PendingConnection {
613 remote_addr: remote,
614 seq_num,
615 });
616 return Ok(());
617 } else {
618 return Err(KernelError::ResourceExhausted {
619 resource: "listen_backlog",
620 });
621 }
622 }
623 }
624 Err(KernelError::InvalidArgument {
625 name: "listen_addr",
626 value: "not_listening",
627 })
628}
629
630pub fn close_socket(id: usize) -> Result<(), KernelError> {
632 with_socket_mut(id, |socket| socket.close())?
633}
634
635#[derive(Debug, Clone)]
637pub struct SocketSummary {
638 pub id: usize,
639 pub domain: SocketDomain,
640 pub socket_type: SocketType,
641 pub state: SocketState,
642 pub local_addr: Option<SocketAddr>,
643 pub remote_addr: Option<SocketAddr>,
644}
645
646pub fn list_sockets() -> Vec<SocketSummary> {
648 let table = SOCKET_TABLE.lock();
649 if let Some(ref sockets) = *table {
650 sockets
651 .iter()
652 .map(|s| SocketSummary {
653 id: s.id,
654 domain: s.domain,
655 socket_type: s.socket_type,
656 state: s.state,
657 local_addr: s.local_addr,
658 remote_addr: s.remote_addr,
659 })
660 .collect()
661 } else {
662 Vec::new()
663 }
664}
665
666pub fn sendto(
672 id: usize,
673 data: &[u8],
674 dest: Option<&crate::net::SocketAddr>,
675) -> Result<usize, KernelError> {
676 with_socket_mut(id, |socket| {
677 if let Some(addr) = dest {
678 socket.send_to(data, *addr, 0)
679 } else {
680 socket.send(data, 0)
681 }
682 })?
683}
684
685pub fn recvfrom(
687 id: usize,
688 buf: &mut [u8],
689) -> Result<(usize, Option<crate::net::SocketAddr>), KernelError> {
690 let result = with_socket_mut(id, |socket| socket.recv_from(buf, 0))??;
691 Ok((result.0, Some(result.1)))
692}
693
694pub fn getsockname(id: usize) -> Result<crate::net::SocketAddr, KernelError> {
696 with_socket(id, |socket| {
697 socket.local_addr.ok_or(KernelError::InvalidState {
698 expected: "bound socket",
699 actual: "unbound",
700 })
701 })?
702}
703
704pub fn getpeername(id: usize) -> Result<crate::net::SocketAddr, KernelError> {
706 with_socket(id, |socket| {
707 socket.remote_addr.ok_or(KernelError::InvalidState {
708 expected: "connected socket",
709 actual: "not connected",
710 })
711 })?
712}
713
714pub fn setsockopt(
716 _id: usize,
717 _level: i32,
718 _optname: i32,
719 _optval_ptr: usize,
720 _optlen: usize,
721) -> Result<usize, KernelError> {
722 Ok(0)
724}
725
726pub fn getsockopt(
728 _id: usize,
729 _level: i32,
730 _optname: i32,
731 _optval_ptr: usize,
732) -> Result<usize, KernelError> {
733 Ok(0)
735}
736
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::{Ipv4Address, Ipv6Address};

    /// Builds a socket for a combination the tests expect to be supported.
    fn make(domain: SocketDomain, kind: SocketType, proto: SocketProtocol) -> Socket {
        Socket::new(domain, kind, proto).unwrap()
    }

    /// Binds a fresh TCP socket of `domain` to `addr` and checks the
    /// Unbound -> Bound transition.
    fn assert_binds(domain: SocketDomain, addr: SocketAddr) {
        let mut socket = make(domain, SocketType::Stream, SocketProtocol::Tcp);
        assert_eq!(socket.state, SocketState::Unbound);
        socket.bind(addr).unwrap();
        assert_eq!(socket.state, SocketState::Bound);
        assert_eq!(socket.local_addr, Some(addr));
    }

    #[test]
    fn test_socket_creation() {
        let socket = make(SocketDomain::Inet, SocketType::Stream, SocketProtocol::Tcp);
        assert_eq!(socket.state, SocketState::Unbound);
        assert_eq!(socket.socket_type, SocketType::Stream);
    }

    #[test]
    fn test_socket_creation_inet6() {
        let socket = make(SocketDomain::Inet6, SocketType::Stream, SocketProtocol::Tcp);
        assert_eq!(socket.state, SocketState::Unbound);
        assert_eq!(socket.domain, SocketDomain::Inet6);
        assert_eq!(socket.socket_type, SocketType::Stream);
    }

    #[test]
    fn test_socket_creation_inet6_udp() {
        let socket = make(SocketDomain::Inet6, SocketType::Dgram, SocketProtocol::Udp);
        assert_eq!(socket.domain, SocketDomain::Inet6);
        assert_eq!(socket.socket_type, SocketType::Dgram);
    }

    #[test]
    fn test_socket_bind() {
        assert_binds(SocketDomain::Inet, SocketAddr::v4(Ipv4Address::LOCALHOST, 8080));
    }

    #[test]
    fn test_socket_bind_inet6() {
        assert_binds(SocketDomain::Inet6, SocketAddr::v6(Ipv6Address::LOCALHOST, 8080));
    }
}
790}