1use alloc::{
15 collections::{BTreeMap, VecDeque},
16 string::{String, ToString},
17 vec::Vec,
18};
19use core::sync::atomic::{AtomicU64, Ordering};
20
21use spin::Mutex;
22
23use crate::error::{KernelError, KernelResult};
24
/// Maximum accepted length, in bytes, of a socket path given to `socket_bind`.
pub const UNIX_PATH_MAX: usize = 108;

/// Upper bound applied to the `backlog` argument of `socket_listen`.
pub const UNIX_BACKLOG_MAX: usize = 128;

/// Largest payload, in bytes, that `socket_sendto` will accept.
pub const UNIX_DGRAM_MAX: usize = 64 * 1024;

/// Maximum number of file descriptors intended for a single `ScmRights` payload.
pub const SCM_RIGHTS_MAX: usize = 16;

/// Maximum number of sockets that `socket_create` allows to exist at once.
pub const UNIX_SOCKET_MAX: usize = 1024;
43
/// Communication semantics of a Unix domain socket.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnixSocketType {
    /// Connection-oriented stream; the only type `socket_listen` accepts.
    Stream,
    /// Message-oriented datagrams.
    Datagram,
}
56
/// Lifecycle state of a Unix domain socket.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnixSocketState {
    /// Freshly created; no path has been bound.
    Unbound,
    /// A path has been registered for the socket via `socket_bind`.
    Bound,
    /// Passive stream socket queueing incoming connections (`socket_listen`).
    Listening,
    /// Has a peer: created by `socketpair`/`socket_accept`, or after `socket_connect`.
    Connected,
    /// NOTE(review): no code in this chunk assigns this state — confirm where it is used.
    Closed,
}
71
/// Ancillary rights payload: file descriptors passed along with a message.
#[derive(Clone, Debug)]
pub struct ScmRights {
    /// Descriptor numbers being transferred.
    pub fds: Vec<u32>,
}
78
79#[derive(Debug, Clone)]
81pub struct UnixMessage {
82 pub data: Vec<u8>,
84 pub rights: Option<ScmRights>,
86 pub sender: u64,
88}
89
90pub struct UnixSocket {
96 pub id: u64,
98 pub socket_type: UnixSocketType,
100 pub state: UnixSocketState,
102 pub path: Option<String>,
104 pub peer_id: Option<u64>,
106 pub recv_buffer: VecDeque<UnixMessage>,
108 pub recv_buffer_max: usize,
110 pub recv_buffer_used: usize,
112 pub pending_connections: VecDeque<u64>,
114 pub backlog: usize,
116 pub shutdown_read: bool,
118 pub shutdown_write: bool,
120 pub owner_pid: u64,
122}
123
124impl UnixSocket {
125 fn new(id: u64, socket_type: UnixSocketType, owner_pid: u64) -> Self {
126 Self {
127 id,
128 socket_type,
129 state: UnixSocketState::Unbound,
130 path: None,
131 peer_id: None,
132 recv_buffer: VecDeque::new(),
133 recv_buffer_max: 65536,
134 recv_buffer_used: 0,
135 pending_connections: VecDeque::new(),
136 backlog: 0,
137 shutdown_read: false,
138 shutdown_write: false,
139 owner_pid,
140 }
141 }
142}
143
144static NEXT_SOCKET_ID: AtomicU64 = AtomicU64::new(1);
150
151static UNIX_SOCKETS: Mutex<BTreeMap<u64, UnixSocket>> = Mutex::new(BTreeMap::new());
153
154static PATH_REGISTRY: Mutex<BTreeMap<String, u64>> = Mutex::new(BTreeMap::new());
156
157pub fn socket_create(socket_type: UnixSocketType, owner_pid: u64) -> KernelResult<u64> {
165 let sockets = UNIX_SOCKETS.lock();
166 if sockets.len() >= UNIX_SOCKET_MAX {
167 return Err(KernelError::ResourceExhausted {
168 resource: "unix_sockets",
169 });
170 }
171 drop(sockets);
172
173 let id = NEXT_SOCKET_ID.fetch_add(1, Ordering::Relaxed);
174 let socket = UnixSocket::new(id, socket_type, owner_pid);
175
176 UNIX_SOCKETS.lock().insert(id, socket);
177 Ok(id)
178}
179
180pub fn socket_bind(socket_id: u64, path: &str) -> KernelResult<()> {
182 if path.is_empty() || path.len() > UNIX_PATH_MAX {
183 return Err(KernelError::InvalidArgument {
184 name: "path",
185 value: "empty or exceeds UNIX_PATH_MAX",
186 });
187 }
188
189 let mut paths = PATH_REGISTRY.lock();
190 if paths.contains_key(path) {
191 return Err(KernelError::AlreadyExists {
192 resource: "unix_socket_path",
193 id: 0,
194 });
195 }
196
197 let mut sockets = UNIX_SOCKETS.lock();
198 let socket = sockets.get_mut(&socket_id).ok_or(KernelError::NotFound {
199 resource: "unix_socket",
200 id: socket_id,
201 })?;
202
203 if socket.state != UnixSocketState::Unbound {
204 return Err(KernelError::InvalidState {
205 expected: "unbound",
206 actual: "already bound or connected",
207 });
208 }
209
210 socket.path = Some(path.to_string());
211 socket.state = UnixSocketState::Bound;
212 paths.insert(path.to_string(), socket_id);
213
214 Ok(())
215}
216
217pub fn socket_listen(socket_id: u64, backlog: usize) -> KernelResult<()> {
219 let mut sockets = UNIX_SOCKETS.lock();
220 let socket = sockets.get_mut(&socket_id).ok_or(KernelError::NotFound {
221 resource: "unix_socket",
222 id: socket_id,
223 })?;
224
225 if socket.socket_type != UnixSocketType::Stream {
226 return Err(KernelError::InvalidArgument {
227 name: "socket_type",
228 value: "listen requires SOCK_STREAM",
229 });
230 }
231
232 if socket.state != UnixSocketState::Bound {
233 return Err(KernelError::InvalidState {
234 expected: "bound",
235 actual: "not bound",
236 });
237 }
238
239 socket.backlog = backlog.min(UNIX_BACKLOG_MAX);
240 socket.state = UnixSocketState::Listening;
241 Ok(())
242}
243
244pub fn socket_connect(socket_id: u64, path: &str) -> KernelResult<()> {
249 let target_id = {
251 let paths = PATH_REGISTRY.lock();
252 *paths.get(path).ok_or(KernelError::NotFound {
253 resource: "unix_socket_path",
254 id: 0,
255 })?
256 };
257
258 let mut sockets = UNIX_SOCKETS.lock();
259
260 let target = sockets.get(&target_id).ok_or(KernelError::NotFound {
262 resource: "unix_socket",
263 id: target_id,
264 })?;
265
266 if target.state != UnixSocketState::Listening {
267 return Err(KernelError::InvalidState {
268 expected: "listening",
269 actual: "not listening",
270 });
271 }
272
273 if target.pending_connections.len() >= target.backlog {
274 return Err(KernelError::ResourceExhausted {
275 resource: "connection backlog",
276 });
277 }
278
279 let target = sockets.get_mut(&target_id).ok_or(KernelError::NotFound {
281 resource: "unix_socket_target",
282 id: target_id,
283 })?;
284 target.pending_connections.push_back(socket_id);
285
286 let socket = sockets.get_mut(&socket_id).ok_or(KernelError::NotFound {
288 resource: "unix_socket",
289 id: socket_id,
290 })?;
291 socket.peer_id = Some(target_id);
292 socket.state = UnixSocketState::Connected;
293
294 Ok(())
295}
296
297pub fn socket_accept(listen_socket_id: u64) -> KernelResult<(u64, u64)> {
302 let mut sockets = UNIX_SOCKETS.lock();
303
304 let listen = sockets
305 .get_mut(&listen_socket_id)
306 .ok_or(KernelError::NotFound {
307 resource: "unix_socket",
308 id: listen_socket_id,
309 })?;
310
311 if listen.state != UnixSocketState::Listening {
312 return Err(KernelError::InvalidState {
313 expected: "listening",
314 actual: "not listening",
315 });
316 }
317
318 let connecting_id = listen
319 .pending_connections
320 .pop_front()
321 .ok_or(KernelError::WouldBlock)?;
322
323 let owner_pid = listen.owner_pid;
324
325 let new_id = NEXT_SOCKET_ID.fetch_add(1, Ordering::Relaxed);
327 let mut new_socket = UnixSocket::new(new_id, UnixSocketType::Stream, owner_pid);
328 new_socket.state = UnixSocketState::Connected;
329 new_socket.peer_id = Some(connecting_id);
330
331 if let Some(connecting) = sockets.get_mut(&connecting_id) {
333 connecting.peer_id = Some(new_id);
334 }
335
336 sockets.insert(new_id, new_socket);
337
338 Ok((new_id, connecting_id))
339}
340
341pub fn socket_send(socket_id: u64, data: &[u8], rights: Option<ScmRights>) -> KernelResult<usize> {
343 let sockets = UNIX_SOCKETS.lock();
344 let socket = sockets.get(&socket_id).ok_or(KernelError::NotFound {
345 resource: "unix_socket",
346 id: socket_id,
347 })?;
348
349 if socket.shutdown_write {
350 return Err(KernelError::InvalidState {
351 expected: "write enabled",
352 actual: "shutdown for writing",
353 });
354 }
355
356 let peer_id = socket.peer_id.ok_or(KernelError::InvalidState {
357 expected: "connected",
358 actual: "not connected",
359 })?;
360 drop(sockets);
361
362 let mut sockets = UNIX_SOCKETS.lock();
364 let peer = sockets.get_mut(&peer_id).ok_or(KernelError::NotFound {
365 resource: "unix_socket",
366 id: peer_id,
367 })?;
368
369 if peer.shutdown_read {
370 return Err(KernelError::InvalidState {
371 expected: "read enabled",
372 actual: "peer shutdown for reading",
373 });
374 }
375
376 if peer.recv_buffer_used + data.len() > peer.recv_buffer_max {
377 return Err(KernelError::ResourceExhausted {
378 resource: "recv_buffer",
379 });
380 }
381
382 let msg = UnixMessage {
383 data: data.to_vec(),
384 rights,
385 sender: socket_id,
386 };
387 let len = data.len();
388 peer.recv_buffer_used += len;
389 peer.recv_buffer.push_back(msg);
390
391 Ok(len)
392}
393
394pub fn socket_recv(socket_id: u64, buf: &mut [u8]) -> KernelResult<(usize, Option<ScmRights>)> {
398 let mut sockets = UNIX_SOCKETS.lock();
399 let socket = sockets.get_mut(&socket_id).ok_or(KernelError::NotFound {
400 resource: "unix_socket",
401 id: socket_id,
402 })?;
403
404 if socket.shutdown_read {
405 return Ok((0, None)); }
407
408 let msg = socket
409 .recv_buffer
410 .pop_front()
411 .ok_or(KernelError::WouldBlock)?;
412
413 let copy_len = buf.len().min(msg.data.len());
414 buf[..copy_len].copy_from_slice(&msg.data[..copy_len]);
415 socket.recv_buffer_used = socket.recv_buffer_used.saturating_sub(msg.data.len());
416
417 Ok((copy_len, msg.rights))
418}
419
420pub fn socketpair(socket_type: UnixSocketType, owner_pid: u64) -> KernelResult<(u64, u64)> {
424 let id_a = NEXT_SOCKET_ID.fetch_add(1, Ordering::Relaxed);
425 let id_b = NEXT_SOCKET_ID.fetch_add(1, Ordering::Relaxed);
426
427 let mut sock_a = UnixSocket::new(id_a, socket_type, owner_pid);
428 let mut sock_b = UnixSocket::new(id_b, socket_type, owner_pid);
429
430 sock_a.state = UnixSocketState::Connected;
431 sock_a.peer_id = Some(id_b);
432 sock_b.state = UnixSocketState::Connected;
433 sock_b.peer_id = Some(id_a);
434
435 let mut sockets = UNIX_SOCKETS.lock();
436 sockets.insert(id_a, sock_a);
437 sockets.insert(id_b, sock_b);
438
439 Ok((id_a, id_b))
440}
441
442pub fn socket_close(socket_id: u64) -> KernelResult<()> {
444 let mut sockets = UNIX_SOCKETS.lock();
445
446 if let Some(socket) = sockets.remove(&socket_id) {
447 if let Some(ref path) = socket.path {
449 PATH_REGISTRY.lock().remove(path);
450 }
451
452 if let Some(peer_id) = socket.peer_id {
454 if let Some(peer) = sockets.get_mut(&peer_id) {
455 peer.peer_id = None;
456 peer.shutdown_read = true;
457 peer.shutdown_write = true;
458 }
459 }
460 }
461
462 Ok(())
463}
464
465pub fn socket_sendto(socket_id: u64, data: &[u8], dest_path: &str) -> KernelResult<usize> {
467 if data.len() > UNIX_DGRAM_MAX {
468 return Err(KernelError::InvalidArgument {
469 name: "data",
470 value: "exceeds UNIX_DGRAM_MAX",
471 });
472 }
473
474 let dest_id = {
475 let paths = PATH_REGISTRY.lock();
476 *paths.get(dest_path).ok_or(KernelError::NotFound {
477 resource: "unix_socket_path",
478 id: 0,
479 })?
480 };
481
482 let mut sockets = UNIX_SOCKETS.lock();
483 let dest = sockets.get_mut(&dest_id).ok_or(KernelError::NotFound {
484 resource: "unix_socket",
485 id: dest_id,
486 })?;
487
488 if dest.recv_buffer_used + data.len() > dest.recv_buffer_max {
489 return Err(KernelError::ResourceExhausted {
490 resource: "recv_buffer",
491 });
492 }
493
494 let msg = UnixMessage {
495 data: data.to_vec(),
496 rights: None,
497 sender: socket_id,
498 };
499 let len = data.len();
500 dest.recv_buffer_used += len;
501 dest.recv_buffer.push_back(msg);
502
503 Ok(len)
504}
505
506pub fn socket_count() -> usize {
508 UNIX_SOCKETS.lock().len()
509}