1use alloc::{collections::BTreeMap, vec::Vec};
7
8use spin::Mutex;
9
10use super::{IpAddress, SocketAddr};
11use crate::error::KernelError;
12
/// Maximum number of payload bytes placed in one outgoing segment;
/// `TcpConnection::send` splits data into chunks of this size.
/// (1460 = 1500-byte Ethernet MTU minus 20-byte IPv4 and 20-byte TCP
/// headers — presumably the intended derivation, TODO confirm.)
const TCP_MSS: u16 = 1460;

/// Length in bytes of a TCP header without options (a data offset of 5 words).
const TCP_HEADER_SIZE: usize = 20;
18
/// The TCP control-flag byte (header byte 13) as a small bit set.
///
/// The associated constants are raw bit masks. They can be OR-ed together
/// and passed to [`TcpFlags::new`], [`TcpFlags::has`] and [`TcpFlags::set`].
#[derive(Debug, Clone, Copy)]
pub struct TcpFlags(u8);

impl TcpFlags {
    /// Sender has no more data.
    pub const FIN: u8 = 1 << 0;
    /// Synchronise sequence numbers (connection setup).
    pub const SYN: u8 = 1 << 1;
    /// Reset the connection.
    pub const RST: u8 = 1 << 2;
    /// Push buffered data to the receiving application.
    pub const PSH: u8 = 1 << 3;
    /// The acknowledgment-number field is significant.
    pub const ACK: u8 = 1 << 4;
    /// The urgent-pointer field is significant.
    pub const URG: u8 = 1 << 5;

    /// Wraps a raw flag byte without interpretation.
    pub fn new(flags: u8) -> Self {
        TcpFlags(flags)
    }

    /// Returns `true` when at least one bit of `flag` is set.
    ///
    /// For a multi-bit mask this is an "any of" test, not "all of".
    pub fn has(&self, flag: u8) -> bool {
        let overlap = self.0 & flag;
        overlap != 0
    }

    /// Turns on every bit of `flag`, leaving other bits untouched.
    pub fn set(&mut self, flag: u8) {
        self.0 |= flag;
    }
}
43
/// Connection state, mirroring the states of the RFC 793 TCP state machine.
///
/// Local actions (`TcpConnection::connect`, `listen`, `close`) and inbound
/// segments (`process_tcp_state_transition`) move a connection between these.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpState {
    /// No connection; the state `TcpConnection::new` starts in.
    Closed,
    /// Waiting for an inbound SYN (entered via `TcpConnection::listen`).
    Listen,
    /// SYN sent by `TcpConnection::connect`; waiting for the peer's SYN+ACK.
    SynSent,
    /// RFC 793: SYN received and answered; waiting for the handshake's final ACK.
    SynReceived,
    /// Handshake complete; data may flow.
    Established,
    /// Our FIN has been sent; waiting for it to be acknowledged.
    FinWait1,
    /// Our FIN has been acknowledged; waiting for the peer's FIN.
    FinWait2,
    /// The peer's FIN has been received; waiting for the local side to close.
    CloseWait,
    /// RFC 793: both sides sent FIN at once; waiting for the ACK of ours.
    Closing,
    /// The peer closed first and our FIN has been sent; waiting for its ACK.
    LastAck,
    /// Both FINs exchanged (RFC 793 waits 2×MSL here before Closed).
    TimeWait,
}
59
/// Endpoint, state and sequence bookkeeping for one TCP connection.
///
/// Sequence and acknowledgment numbers use wrapping (mod 2^32) arithmetic.
#[derive(Debug, Clone)]
pub struct TcpConnection {
    /// Local endpoint; its port is the source port of outgoing segments.
    pub local: SocketAddr,
    /// Remote endpoint; outgoing segments are sent to its address and port.
    pub remote: SocketAddr,
    /// Current position in the connection state machine.
    pub state: TcpState,
    /// Next sequence number to send; advanced by data length, and by one for SYN/FIN.
    pub seq_num: u32,
    /// Acknowledgment number placed in outgoing segments.
    pub ack_num: u32,
    /// Receive window advertised in outgoing segments.
    pub window_size: u16,
}
70
71impl TcpConnection {
72 pub fn new(local: SocketAddr, remote: SocketAddr) -> Self {
73 Self {
74 local,
75 remote,
76 state: TcpState::Closed,
77 seq_num: 0,
78 ack_num: 0,
79 window_size: 65535,
80 }
81 }
82
83 pub fn connect(&mut self) -> Result<(), KernelError> {
85 if self.state != TcpState::Closed {
86 return Err(KernelError::InvalidState {
87 expected: "Closed",
88 actual: "Other",
89 });
90 }
91
92 self.seq_num = generate_initial_seq();
93 let syn = build_tcp_segment(
95 self.local.port(),
96 self.remote.port(),
97 self.seq_num,
98 0,
99 TcpFlags::SYN,
100 self.window_size,
101 &[],
102 );
103 send_tcp_via_ip(self.remote.ip(), &syn)?;
104 self.seq_num = self.seq_num.wrapping_add(1); self.state = TcpState::SynSent;
106
107 Ok(())
108 }
109
110 pub fn listen(&mut self) -> Result<(), KernelError> {
112 if self.state != TcpState::Closed {
113 return Err(KernelError::InvalidState {
114 expected: "Closed",
115 actual: "Other",
116 });
117 }
118
119 self.state = TcpState::Listen;
120 Ok(())
121 }
122
123 pub fn send(&mut self, data: &[u8]) -> Result<usize, KernelError> {
125 if self.state != TcpState::Established {
126 return Err(KernelError::InvalidState {
127 expected: "Established",
128 actual: "Other",
129 });
130 }
131
132 let mss = TCP_MSS as usize;
133 let mut offset = 0;
134 while offset < data.len() {
135 let end = (offset + mss).min(data.len());
136 let chunk = &data[offset..end];
137
138 let flags = TcpFlags::ACK | if end == data.len() { TcpFlags::PSH } else { 0 };
139 let seg = build_tcp_segment(
140 self.local.port(),
141 self.remote.port(),
142 self.seq_num,
143 self.ack_num,
144 flags,
145 self.window_size,
146 chunk,
147 );
148 let _ = send_tcp_via_ip(self.remote.ip(), &seg);
149
150 self.seq_num = self.seq_num.wrapping_add(chunk.len() as u32);
151 offset = end;
152 }
153
154 Ok(data.len())
155 }
156
157 pub fn recv(&mut self, _buffer: &mut [u8]) -> Result<usize, KernelError> {
159 if self.state != TcpState::Established {
160 return Err(KernelError::InvalidState {
161 expected: "Established",
162 actual: "Other",
163 });
164 }
165
166 Ok(0)
169 }
170
171 pub fn close(&mut self) -> Result<(), KernelError> {
173 match self.state {
174 TcpState::Established => {
175 let fin_ack = build_tcp_segment(
176 self.local.port(),
177 self.remote.port(),
178 self.seq_num,
179 self.ack_num,
180 TcpFlags::FIN | TcpFlags::ACK,
181 self.window_size,
182 &[],
183 );
184 let _ = send_tcp_via_ip(self.remote.ip(), &fin_ack);
185 self.seq_num = self.seq_num.wrapping_add(1); self.state = TcpState::FinWait1;
187 Ok(())
188 }
189 TcpState::CloseWait => {
190 let fin_ack = build_tcp_segment(
191 self.local.port(),
192 self.remote.port(),
193 self.seq_num,
194 self.ack_num,
195 TcpFlags::FIN | TcpFlags::ACK,
196 self.window_size,
197 &[],
198 );
199 let _ = send_tcp_via_ip(self.remote.ip(), &fin_ack);
200 self.seq_num = self.seq_num.wrapping_add(1);
201 self.state = TcpState::LastAck;
202 Ok(())
203 }
204 _ => Err(KernelError::InvalidState {
205 expected: "Established or CloseWait",
206 actual: "Other",
207 }),
208 }
209 }
210}
211
212pub fn init() -> Result<(), KernelError> {
214 println!("[TCP] Initializing TCP protocol...");
215 println!("[TCP] TCP initialized");
216 Ok(())
217}
218
219#[allow(dead_code)] fn build_tcp_segment(
230 src_port: u16,
231 dst_port: u16,
232 seq_num: u32,
233 ack_num: u32,
234 flags: u8,
235 window: u16,
236 payload: &[u8],
237) -> Vec<u8> {
238 let data_offset: u8 = 5; let mut seg = Vec::with_capacity(TCP_HEADER_SIZE + payload.len());
240
241 seg.extend_from_slice(&src_port.to_be_bytes());
242 seg.extend_from_slice(&dst_port.to_be_bytes());
243 seg.extend_from_slice(&seq_num.to_be_bytes());
244 seg.extend_from_slice(&ack_num.to_be_bytes());
245 seg.push(data_offset << 4); seg.push(flags);
247 seg.extend_from_slice(&window.to_be_bytes());
248 seg.extend_from_slice(&0u16.to_be_bytes()); seg.extend_from_slice(&0u16.to_be_bytes()); seg.extend_from_slice(payload);
252 seg
253}
254
255#[allow(dead_code)] fn send_tcp_via_ip(dest: super::IpAddress, segment: &[u8]) -> Result<(), KernelError> {
258 super::ip::send(dest, super::ip::IpProtocol::Tcp, segment)
259}
260
261fn process_tcp_state_transition(
266 state: &mut TcpSocketState,
267 flags: TcpFlags,
268 seq_num: u32,
269 _ack_num: u32,
270 payload: &[u8],
271 _src_addr: IpAddress,
272 _src_port: u16,
273) {
274 match state.connection.state {
275 TcpState::SynSent => {
276 if flags.has(TcpFlags::SYN) && flags.has(TcpFlags::ACK) {
278 state.recv_seq = seq_num.wrapping_add(1);
279 state.connection.ack_num = state.recv_seq;
280 state.connection.state = TcpState::Established;
281
282 let ack = build_tcp_segment(
284 state.connection.local.port(),
285 state.connection.remote.port(),
286 state.connection.seq_num,
287 state.connection.ack_num,
288 TcpFlags::ACK,
289 state.connection.window_size,
290 &[],
291 );
292 let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
293 }
294 }
295 TcpState::Listen => {
296 }
298 TcpState::SynReceived => {
299 if flags.has(TcpFlags::ACK) {
300 state.connection.state = TcpState::Established;
301 }
302 }
303 TcpState::Established => {
304 if !payload.is_empty() {
306 state.recv_buffer.extend_from_slice(payload);
307 state.recv_seq = seq_num.wrapping_add(payload.len() as u32);
308 state.connection.ack_num = state.recv_seq;
309
310 let ack = build_tcp_segment(
312 state.connection.local.port(),
313 state.connection.remote.port(),
314 state.connection.seq_num,
315 state.connection.ack_num,
316 TcpFlags::ACK,
317 state.connection.window_size,
318 &[],
319 );
320 let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
321 }
322
323 if flags.has(TcpFlags::FIN) {
325 state.recv_seq = state.recv_seq.wrapping_add(1);
326 state.connection.ack_num = state.recv_seq;
327 state.connection.state = TcpState::CloseWait;
328
329 let ack = build_tcp_segment(
331 state.connection.local.port(),
332 state.connection.remote.port(),
333 state.connection.seq_num,
334 state.connection.ack_num,
335 TcpFlags::ACK,
336 state.connection.window_size,
337 &[],
338 );
339 let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
340 }
341 }
342 TcpState::FinWait1 => {
343 if flags.has(TcpFlags::FIN) && flags.has(TcpFlags::ACK) {
344 state.connection.ack_num = seq_num.wrapping_add(1);
346 state.connection.state = TcpState::TimeWait;
347 let ack = build_tcp_segment(
349 state.connection.local.port(),
350 state.connection.remote.port(),
351 state.connection.seq_num,
352 state.connection.ack_num,
353 TcpFlags::ACK,
354 state.connection.window_size,
355 &[],
356 );
357 let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
358 } else if flags.has(TcpFlags::ACK) {
359 state.connection.state = TcpState::FinWait2;
360 }
361 }
362 TcpState::FinWait2 => {
363 if flags.has(TcpFlags::FIN) {
364 state.connection.ack_num = seq_num.wrapping_add(1);
365 state.connection.state = TcpState::TimeWait;
366 let ack = build_tcp_segment(
367 state.connection.local.port(),
368 state.connection.remote.port(),
369 state.connection.seq_num,
370 state.connection.ack_num,
371 TcpFlags::ACK,
372 state.connection.window_size,
373 &[],
374 );
375 let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
376 }
377 }
378 TcpState::LastAck => {
379 if flags.has(TcpFlags::ACK) {
380 state.connection.state = TcpState::Closed;
381 }
382 }
383 TcpState::TimeWait => {
384 if flags.has(TcpFlags::FIN) {
386 let ack = build_tcp_segment(
387 state.connection.local.port(),
388 state.connection.remote.port(),
389 state.connection.seq_num,
390 state.connection.ack_num,
391 TcpFlags::ACK,
392 state.connection.window_size,
393 &[],
394 );
395 let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
396 }
397 }
398 _ => {}
399 }
400}
401
/// Per-socket TCP bookkeeping stored in `TCP_CONNECTIONS`.
struct TcpSocketState {
    /// Endpoints, state, and the sequence/ack numbers used in outgoing segments.
    connection: TcpConnection,
    /// Staging buffer for outgoing bytes; nothing in this module keeps data
    /// queued here between calls.
    send_buffer: Vec<u8>,
    /// Payload received while `Established`, waiting to be drained by `receive_data`.
    recv_buffer: Vec<u8>,
    /// Counter seeded from `generate_initial_seq` and advanced by
    /// `transmit_data`; independent of `connection.seq_num`.
    send_seq: u32,
    /// Next sequence number expected from the peer, as set when inbound
    /// segments are processed; mirrored into `connection.ack_num`.
    recv_seq: u32,
}

/// Connection table keyed by socket id, shared by every entry point in this module.
static TCP_CONNECTIONS: Mutex<BTreeMap<usize, TcpSocketState>> = Mutex::new(BTreeMap::new());
417
418pub fn transmit_data(socket_id: usize, data: &[u8], remote: SocketAddr) {
420 let mut connections = TCP_CONNECTIONS.lock();
421
422 let state = connections.entry(socket_id).or_insert_with(|| {
424 let local = SocketAddr::v4(super::Ipv4Address::UNSPECIFIED, 0);
425 TcpSocketState {
426 connection: TcpConnection::new(local, remote),
427 send_buffer: Vec::new(),
428 recv_buffer: Vec::new(),
429 send_seq: generate_initial_seq(),
430 recv_seq: 0,
431 }
432 });
433
434 state.connection.remote = remote;
436 if state.connection.state == TcpState::Closed {
437 state.connection.state = TcpState::Established;
438 }
439
440 state.send_buffer.extend_from_slice(data);
442
443 let bytes_sent = data.len();
451 state.send_seq = state.send_seq.wrapping_add(bytes_sent as u32);
452 state.send_buffer.clear();
453
454 #[cfg(feature = "net_debug")]
455 println!(
456 "[TCP] Transmitted {} bytes to {:?} (socket {})",
457 bytes_sent, remote, socket_id
458 );
459}
460
461pub fn receive_data(socket_id: usize, buffer: &mut Vec<u8>) -> usize {
463 let mut connections = TCP_CONNECTIONS.lock();
464
465 if let Some(state) = connections.get_mut(&socket_id) {
466 if state.connection.state != TcpState::Established {
467 return 0;
468 }
469
470 let bytes_available = state.recv_buffer.len();
472 if bytes_available > 0 {
473 buffer.extend_from_slice(&state.recv_buffer);
474 state.recv_buffer.clear();
475 state.recv_seq = state.recv_seq.wrapping_add(bytes_available as u32);
476
477 #[cfg(feature = "net_debug")]
478 println!(
479 "[TCP] Received {} bytes from socket {}",
480 bytes_available, socket_id
481 );
482
483 return bytes_available;
484 }
485 }
486
487 0
488}
489
490pub fn close_connection(socket_id: usize) {
492 let mut connections = TCP_CONNECTIONS.lock();
493
494 if let Some(state) = connections.get_mut(&socket_id) {
495 match state.connection.state {
497 TcpState::Established => {
498 state.connection.state = TcpState::FinWait1;
500
501 state.connection.state = TcpState::Closed;
504 }
505 TcpState::CloseWait => {
506 state.connection.state = TcpState::LastAck;
508 state.connection.state = TcpState::Closed;
509 }
510 _ => {
511 state.connection.state = TcpState::Closed;
513 }
514 }
515
516 state.send_buffer.clear();
518 state.recv_buffer.clear();
519 }
520
521 connections.remove(&socket_id);
523
524 #[cfg(feature = "net_debug")]
525 println!("[TCP] Closed connection for socket {}", socket_id);
526}
527
528pub fn process_packet(
533 src_addr: super::IpAddress,
534 _dst_addr: super::IpAddress,
535 data: &[u8],
536) -> Result<(), KernelError> {
537 if data.len() < TCP_HEADER_SIZE {
538 return Err(KernelError::InvalidArgument {
539 name: "tcp_packet",
540 value: "too_short",
541 });
542 }
543
544 let src_port = u16::from_be_bytes([data[0], data[1]]);
546 let dst_port = u16::from_be_bytes([data[2], data[3]]);
547 let seq_num = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
548 let ack_num = u32::from_be_bytes([data[8], data[9], data[10], data[11]]);
549 let data_offset = ((data[12] >> 4) * 4) as usize;
550 let flags = TcpFlags::new(data[13]);
551 let _window = u16::from_be_bytes([data[14], data[15]]);
552
553 let payload = if data.len() > data_offset {
555 &data[data_offset..]
556 } else {
557 &[]
558 };
559
560 let mut connections = TCP_CONNECTIONS.lock();
561 let remote = SocketAddr::new(src_addr, src_port);
562
563 for (_socket_id, state) in connections.iter_mut() {
565 if state.connection.remote == remote
566 || (state.connection.state == TcpState::Listen
567 && state.connection.local.port() == dst_port)
568 {
569 if flags.has(TcpFlags::SYN)
571 && !flags.has(TcpFlags::ACK)
572 && state.connection.state == TcpState::Listen
573 {
574 let local_addr = state.connection.local;
575 if let Err(_e) =
576 super::socket::queue_pending_connection(local_addr, remote, seq_num)
577 {
578 #[cfg(feature = "net_debug")]
579 println!("[TCP] Failed to queue connection: {:?}", _e);
580 }
581 return Ok(());
582 }
583
584 process_tcp_state_transition(
586 state, flags, seq_num, ack_num, payload, src_addr, src_port,
587 );
588
589 return Ok(());
590 }
591 }
592
593 if !flags.has(TcpFlags::RST) {
595 let rst = build_tcp_segment(
596 dst_port,
597 src_port,
598 ack_num,
599 seq_num.wrapping_add(payload.len() as u32),
600 TcpFlags::RST | TcpFlags::ACK,
601 0,
602 &[],
603 );
604 let _ = send_tcp_via_ip(src_addr, &rst);
605 }
606
607 Ok(())
608}
609
/// Returns the initial sequence number for a new connection.
///
/// Values come from a shared counter that starts at 1_000_000 and grows by
/// one per call, wrapping on overflow.
/// NOTE(review): such ISNs are trivially predictable; RFC 6528 recommends
/// mixing in a secret and a clock.
fn generate_initial_seq() -> u32 {
    use core::sync::atomic::{AtomicU32, Ordering};

    static NEXT_ISN: AtomicU32 = AtomicU32::new(1_000_000);
    NEXT_ISN.fetch_add(1, Ordering::Relaxed)
}
617
618pub fn get_stats() -> TcpStats {
620 let connections = TCP_CONNECTIONS.lock();
621 TcpStats {
622 active_connections: connections.len(),
623 total_bytes_sent: 0, total_bytes_recv: 0, retransmissions: 0, }
627}
628
/// Snapshot of TCP statistics, as returned by `get_stats`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TcpStats {
    /// Number of entries currently in the connection table.
    pub active_connections: usize,
    /// Total payload bytes sent. NOTE(review): `get_stats` always reports 0.
    pub total_bytes_sent: u64,
    /// Total payload bytes received. NOTE(review): `get_stats` always reports 0.
    pub total_bytes_recv: u64,
    /// Retransmitted segments. NOTE(review): `get_stats` always reports 0.
    pub retransmissions: u64,
}
637
#[cfg(test)]
mod tests {
    //! Unit tests for pieces that need no network I/O: flag handling,
    //! segment serialisation and connection state guards.

    use super::*;
    use crate::net::Ipv4Address;

    #[test]
    fn test_tcp_flags() {
        let mut flags = TcpFlags::new(0);
        flags.set(TcpFlags::SYN);
        assert!(flags.has(TcpFlags::SYN));
        assert!(!flags.has(TcpFlags::ACK));
    }

    #[test]
    fn test_tcp_flags_combined_mask() {
        let flags = TcpFlags::new(TcpFlags::SYN | TcpFlags::ACK);
        assert!(flags.has(TcpFlags::SYN));
        assert!(flags.has(TcpFlags::ACK));
        assert!(!flags.has(TcpFlags::FIN));
        assert!(!flags.has(TcpFlags::RST));
    }

    #[test]
    fn test_tcp_connection() {
        let local = SocketAddr::v4(Ipv4Address::LOCALHOST, 8080);
        let remote = SocketAddr::v4(Ipv4Address::new(192, 168, 1, 1), 80);
        let conn = TcpConnection::new(local, remote);

        assert_eq!(conn.state, TcpState::Closed);
    }

    #[test]
    fn test_state_guards_reject_closed_connection() {
        // None of these reach the IP layer: the state check fails first.
        let local = SocketAddr::v4(Ipv4Address::LOCALHOST, 8080);
        let remote = SocketAddr::v4(Ipv4Address::new(192, 168, 1, 1), 80);
        let mut conn = TcpConnection::new(local, remote);

        assert!(conn.send(b"data").is_err());
        assert!(conn.recv(&mut [0u8; 4]).is_err());
        assert!(conn.close().is_err());
        assert!(conn.listen().is_ok());
        assert_eq!(conn.state, TcpState::Listen);
        assert!(conn.listen().is_err());
    }

    #[test]
    fn test_build_tcp_segment_layout() {
        let seg = build_tcp_segment(
            0x1234,
            80,
            0x0102_0304,
            0x0A0B_0C0D,
            TcpFlags::ACK | TcpFlags::PSH,
            0xFFFF,
            b"hi",
        );
        let expected_header: [u8; 20] = [
            0x12, 0x34, 0x00, 0x50, 0x01, 0x02, 0x03, 0x04, 0x0A, 0x0B, 0x0C, 0x0D, 0x50, 0x18,
            0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00,
        ];
        assert_eq!(seg.len(), TCP_HEADER_SIZE + 2);
        assert_eq!(&seg[..TCP_HEADER_SIZE], &expected_header[..]);
        assert_eq!(&seg[TCP_HEADER_SIZE..], b"hi");
    }
}