⚠️ VeridianOS Kernel Documentation - This is low-level kernel code. All functions are unsafe unless explicitly marked otherwise. no_std

veridian_kernel/net/
tcp.rs

1//! TCP protocol implementation
2//!
3//! Implements the TCP state machine with 3-way handshake, data transfer
4//! with sequence numbers, simple retransmission, and orderly close.
5
6use alloc::{collections::BTreeMap, vec::Vec};
7
8use spin::Mutex;
9
10use super::{IpAddress, SocketAddr};
11use crate::error::KernelError;
12
/// Maximum Segment Size (standard for Ethernet).
///
/// `TcpConnection::send` splits outgoing data into chunks of at most this
/// many payload bytes.
const TCP_MSS: u16 = 1460;

/// TCP header size in bytes (no options).
///
/// `build_tcp_segment` reserves this much space ahead of the payload, and
/// `process_packet` rejects any input shorter than this.
const TCP_HEADER_SIZE: usize = 20;
18
/// Set of TCP control bits, stored as the raw flags byte of a TCP header.
#[derive(Debug, Clone, Copy)]
pub struct TcpFlags(u8);

impl TcpFlags {
    /// No more data from sender.
    pub const FIN: u8 = 1 << 0;
    /// Synchronize sequence numbers.
    pub const SYN: u8 = 1 << 1;
    /// Reset the connection.
    pub const RST: u8 = 1 << 2;
    /// Push buffered data to the receiving application.
    pub const PSH: u8 = 1 << 3;
    /// Acknowledgment field is significant.
    pub const ACK: u8 = 1 << 4;
    /// Urgent pointer field is significant.
    pub const URG: u8 = 1 << 5;

    /// Wrap a raw flags byte.
    pub fn new(flags: u8) -> Self {
        TcpFlags(flags)
    }

    /// Return `true` when at least one bit of `flag` is set in this value.
    pub fn has(&self, flag: u8) -> bool {
        let overlap = self.0 & flag;
        overlap != 0
    }

    /// Turn on every bit of `flag`, leaving the other bits untouched.
    pub fn set(&mut self, flag: u8) {
        let TcpFlags(bits) = self;
        *bits |= flag;
    }
}
43
/// TCP connection state.
///
/// The per-variant notes below describe only the transitions that are
/// visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpState {
    /// Initial state set by `TcpConnection::new`; also entered from
    /// `LastAck` when an ACK arrives.
    Closed,
    /// Entered by `TcpConnection::listen`; incoming SYNs are handed to the
    /// socket layer's pending-connection queue by `process_packet`.
    Listen,
    /// Entered by `TcpConnection::connect` once the SYN has been sent.
    SynSent,
    /// Moves to `Established` when a segment with ACK arrives. Nothing in
    /// this file enters this state.
    SynReceived,
    /// Required by `send`, `recv` and (together with `CloseWait`) `close`.
    Established,
    /// Entered by `TcpConnection::close` from `Established` after our FIN.
    FinWait1,
    /// Entered from `FinWait1` when an ACK without FIN arrives.
    FinWait2,
    /// Entered from `Established` when the peer's FIN arrives.
    CloseWait,
    /// Declared but neither entered nor handled by the transitions in this
    /// file.
    Closing,
    /// Entered by `TcpConnection::close` from `CloseWait` after our FIN.
    LastAck,
    /// Entered from `FinWait1` (FIN+ACK) or `FinWait2` (FIN); a repeated FIN
    /// is answered with another ACK.
    TimeWait,
}
59
/// TCP connection.
///
/// Holds the addressing and sequence-number bookkeeping for one connection.
#[derive(Debug, Clone)]
pub struct TcpConnection {
    /// Local endpoint; its port is used as the source port of outgoing segments.
    pub local: SocketAddr,
    /// Remote endpoint; outgoing segments are sent to its IP and port.
    pub remote: SocketAddr,
    /// Current position in the TCP state machine.
    pub state: TcpState,
    /// Sequence number placed in the next outgoing segment.
    pub seq_num: u32,
    /// Acknowledgment number placed in outgoing segments.
    pub ack_num: u32,
    /// Window value advertised in outgoing segments (65535 after `new`).
    pub window_size: u16,
}
70
71impl TcpConnection {
72    pub fn new(local: SocketAddr, remote: SocketAddr) -> Self {
73        Self {
74            local,
75            remote,
76            state: TcpState::Closed,
77            seq_num: 0,
78            ack_num: 0,
79            window_size: 65535,
80        }
81    }
82
83    /// Initiate connection (active open) -- sends SYN via IP layer.
84    pub fn connect(&mut self) -> Result<(), KernelError> {
85        if self.state != TcpState::Closed {
86            return Err(KernelError::InvalidState {
87                expected: "Closed",
88                actual: "Other",
89            });
90        }
91
92        self.seq_num = generate_initial_seq();
93        // Build and send SYN segment
94        let syn = build_tcp_segment(
95            self.local.port(),
96            self.remote.port(),
97            self.seq_num,
98            0,
99            TcpFlags::SYN,
100            self.window_size,
101            &[],
102        );
103        send_tcp_via_ip(self.remote.ip(), &syn)?;
104        self.seq_num = self.seq_num.wrapping_add(1); // SYN consumes one sequence number
105        self.state = TcpState::SynSent;
106
107        Ok(())
108    }
109
110    /// Listen for connections (passive open)
111    pub fn listen(&mut self) -> Result<(), KernelError> {
112        if self.state != TcpState::Closed {
113            return Err(KernelError::InvalidState {
114                expected: "Closed",
115                actual: "Other",
116            });
117        }
118
119        self.state = TcpState::Listen;
120        Ok(())
121    }
122
123    /// Send data by segmenting into MSS-sized chunks.
124    pub fn send(&mut self, data: &[u8]) -> Result<usize, KernelError> {
125        if self.state != TcpState::Established {
126            return Err(KernelError::InvalidState {
127                expected: "Established",
128                actual: "Other",
129            });
130        }
131
132        let mss = TCP_MSS as usize;
133        let mut offset = 0;
134        while offset < data.len() {
135            let end = (offset + mss).min(data.len());
136            let chunk = &data[offset..end];
137
138            let flags = TcpFlags::ACK | if end == data.len() { TcpFlags::PSH } else { 0 };
139            let seg = build_tcp_segment(
140                self.local.port(),
141                self.remote.port(),
142                self.seq_num,
143                self.ack_num,
144                flags,
145                self.window_size,
146                chunk,
147            );
148            let _ = send_tcp_via_ip(self.remote.ip(), &seg);
149
150            self.seq_num = self.seq_num.wrapping_add(chunk.len() as u32);
151            offset = end;
152        }
153
154        Ok(data.len())
155    }
156
157    /// Receive data from the connection's receive buffer.
158    pub fn recv(&mut self, _buffer: &mut [u8]) -> Result<usize, KernelError> {
159        if self.state != TcpState::Established {
160            return Err(KernelError::InvalidState {
161                expected: "Established",
162                actual: "Other",
163            });
164        }
165
166        // Data arrives via process_packet() into TcpSocketState.recv_buffer;
167        // the socket layer retrieves it through receive_data().
168        Ok(0)
169    }
170
171    /// Close connection by sending FIN.
172    pub fn close(&mut self) -> Result<(), KernelError> {
173        match self.state {
174            TcpState::Established => {
175                let fin_ack = build_tcp_segment(
176                    self.local.port(),
177                    self.remote.port(),
178                    self.seq_num,
179                    self.ack_num,
180                    TcpFlags::FIN | TcpFlags::ACK,
181                    self.window_size,
182                    &[],
183                );
184                let _ = send_tcp_via_ip(self.remote.ip(), &fin_ack);
185                self.seq_num = self.seq_num.wrapping_add(1); // FIN consumes one seq
186                self.state = TcpState::FinWait1;
187                Ok(())
188            }
189            TcpState::CloseWait => {
190                let fin_ack = build_tcp_segment(
191                    self.local.port(),
192                    self.remote.port(),
193                    self.seq_num,
194                    self.ack_num,
195                    TcpFlags::FIN | TcpFlags::ACK,
196                    self.window_size,
197                    &[],
198                );
199                let _ = send_tcp_via_ip(self.remote.ip(), &fin_ack);
200                self.seq_num = self.seq_num.wrapping_add(1);
201                self.state = TcpState::LastAck;
202                Ok(())
203            }
204            _ => Err(KernelError::InvalidState {
205                expected: "Established or CloseWait",
206                actual: "Other",
207            }),
208        }
209    }
210}
211
/// Initialize TCP.
///
/// Currently this only logs two progress messages; it sets up no state and
/// always returns `Ok(())`.
pub fn init() -> Result<(), KernelError> {
    println!("[TCP] Initializing TCP protocol...");
    println!("[TCP] TCP initialized");
    Ok(())
}
218
219// ============================================================================
220// TCP Segment Construction and Transmission
221// ============================================================================
222
223/// Build a raw TCP segment (header + payload).
224///
225/// Constructs a 20-byte TCP header with the given parameters followed by
226/// the payload data. Checksum is set to 0 (pseudo-header checksum would
227/// require knowing the IP addresses at this layer).
228#[allow(dead_code)] // Phase 6 network stack -- called from TcpConnection methods
229fn build_tcp_segment(
230    src_port: u16,
231    dst_port: u16,
232    seq_num: u32,
233    ack_num: u32,
234    flags: u8,
235    window: u16,
236    payload: &[u8],
237) -> Vec<u8> {
238    let data_offset: u8 = 5; // 5 x 4 = 20 bytes, no options
239    let mut seg = Vec::with_capacity(TCP_HEADER_SIZE + payload.len());
240
241    seg.extend_from_slice(&src_port.to_be_bytes());
242    seg.extend_from_slice(&dst_port.to_be_bytes());
243    seg.extend_from_slice(&seq_num.to_be_bytes());
244    seg.extend_from_slice(&ack_num.to_be_bytes());
245    seg.push(data_offset << 4); // Data offset in upper nibble
246    seg.push(flags);
247    seg.extend_from_slice(&window.to_be_bytes());
248    seg.extend_from_slice(&0u16.to_be_bytes()); // Checksum (0 for now)
249    seg.extend_from_slice(&0u16.to_be_bytes()); // Urgent pointer
250
251    seg.extend_from_slice(payload);
252    seg
253}
254
255/// Send a TCP segment through the IP layer.
256#[allow(dead_code)] // Phase 6 network stack -- called from TcpConnection methods
257fn send_tcp_via_ip(dest: super::IpAddress, segment: &[u8]) -> Result<(), KernelError> {
258    super::ip::send(dest, super::ip::IpProtocol::Tcp, segment)
259}
260
261/// Process a TCP state transition for an incoming segment.
262///
263/// Handles SYN-ACK (for active open), ACK (for handshake completion),
264/// data delivery, and FIN processing according to the TCP state machine.
265fn process_tcp_state_transition(
266    state: &mut TcpSocketState,
267    flags: TcpFlags,
268    seq_num: u32,
269    _ack_num: u32,
270    payload: &[u8],
271    _src_addr: IpAddress,
272    _src_port: u16,
273) {
274    match state.connection.state {
275        TcpState::SynSent => {
276            // Expecting SYN-ACK
277            if flags.has(TcpFlags::SYN) && flags.has(TcpFlags::ACK) {
278                state.recv_seq = seq_num.wrapping_add(1);
279                state.connection.ack_num = state.recv_seq;
280                state.connection.state = TcpState::Established;
281
282                // Send ACK to complete 3-way handshake
283                let ack = build_tcp_segment(
284                    state.connection.local.port(),
285                    state.connection.remote.port(),
286                    state.connection.seq_num,
287                    state.connection.ack_num,
288                    TcpFlags::ACK,
289                    state.connection.window_size,
290                    &[],
291                );
292                let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
293            }
294        }
295        TcpState::Listen => {
296            // SYN received -- handled separately via queue_pending_connection
297        }
298        TcpState::SynReceived => {
299            if flags.has(TcpFlags::ACK) {
300                state.connection.state = TcpState::Established;
301            }
302        }
303        TcpState::Established => {
304            // Deliver payload data
305            if !payload.is_empty() {
306                state.recv_buffer.extend_from_slice(payload);
307                state.recv_seq = seq_num.wrapping_add(payload.len() as u32);
308                state.connection.ack_num = state.recv_seq;
309
310                // Send ACK for received data
311                let ack = build_tcp_segment(
312                    state.connection.local.port(),
313                    state.connection.remote.port(),
314                    state.connection.seq_num,
315                    state.connection.ack_num,
316                    TcpFlags::ACK,
317                    state.connection.window_size,
318                    &[],
319                );
320                let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
321            }
322
323            // Check for FIN
324            if flags.has(TcpFlags::FIN) {
325                state.recv_seq = state.recv_seq.wrapping_add(1);
326                state.connection.ack_num = state.recv_seq;
327                state.connection.state = TcpState::CloseWait;
328
329                // Send ACK for FIN
330                let ack = build_tcp_segment(
331                    state.connection.local.port(),
332                    state.connection.remote.port(),
333                    state.connection.seq_num,
334                    state.connection.ack_num,
335                    TcpFlags::ACK,
336                    state.connection.window_size,
337                    &[],
338                );
339                let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
340            }
341        }
342        TcpState::FinWait1 => {
343            if flags.has(TcpFlags::FIN) && flags.has(TcpFlags::ACK) {
344                // Simultaneous close or FIN+ACK response
345                state.connection.ack_num = seq_num.wrapping_add(1);
346                state.connection.state = TcpState::TimeWait;
347                // Send ACK
348                let ack = build_tcp_segment(
349                    state.connection.local.port(),
350                    state.connection.remote.port(),
351                    state.connection.seq_num,
352                    state.connection.ack_num,
353                    TcpFlags::ACK,
354                    state.connection.window_size,
355                    &[],
356                );
357                let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
358            } else if flags.has(TcpFlags::ACK) {
359                state.connection.state = TcpState::FinWait2;
360            }
361        }
362        TcpState::FinWait2 => {
363            if flags.has(TcpFlags::FIN) {
364                state.connection.ack_num = seq_num.wrapping_add(1);
365                state.connection.state = TcpState::TimeWait;
366                let ack = build_tcp_segment(
367                    state.connection.local.port(),
368                    state.connection.remote.port(),
369                    state.connection.seq_num,
370                    state.connection.ack_num,
371                    TcpFlags::ACK,
372                    state.connection.window_size,
373                    &[],
374                );
375                let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
376            }
377        }
378        TcpState::LastAck => {
379            if flags.has(TcpFlags::ACK) {
380                state.connection.state = TcpState::Closed;
381            }
382        }
383        TcpState::TimeWait => {
384            // In TIME_WAIT, respond to any retransmitted FIN with ACK
385            if flags.has(TcpFlags::FIN) {
386                let ack = build_tcp_segment(
387                    state.connection.local.port(),
388                    state.connection.remote.port(),
389                    state.connection.seq_num,
390                    state.connection.ack_num,
391                    TcpFlags::ACK,
392                    state.connection.window_size,
393                    &[],
394                );
395                let _ = send_tcp_via_ip(state.connection.remote.ip(), &ack);
396            }
397        }
398        _ => {}
399    }
400}
401
402// ============================================================================
403// Socket Layer Interface
404// ============================================================================
405
/// TCP connection state for socket layer.
///
/// One of these is stored per socket id in the global connection table.
struct TcpSocketState {
    /// Protocol-level connection record (addresses, state, seq/ack numbers).
    connection: TcpConnection,
    /// Staging buffer for outgoing bytes; left empty after every
    /// `transmit_data` call.
    send_buffer: Vec<u8>,
    /// Payload bytes appended by `process_tcp_state_transition` and handed
    /// to the socket layer by `receive_data`.
    recv_buffer: Vec<u8>,
    /// Send-side sequence counter; seeded from `generate_initial_seq` and
    /// advanced by `transmit_data` by the number of bytes it is given.
    send_seq: u32,
    /// Receive-side sequence counter; updated as incoming segments are
    /// processed.
    recv_seq: u32,
}
414
/// Global TCP connection table.
///
/// Maps the socket id passed to `transmit_data`, `receive_data` and
/// `close_connection` to that socket's `TcpSocketState`. All access goes
/// through the mutex; `process_packet` also holds it while dispatching an
/// incoming segment to the state machine.
static TCP_CONNECTIONS: Mutex<BTreeMap<usize, TcpSocketState>> = Mutex::new(BTreeMap::new());
417
418/// Transmit data from socket layer
419pub fn transmit_data(socket_id: usize, data: &[u8], remote: SocketAddr) {
420    let mut connections = TCP_CONNECTIONS.lock();
421
422    // Get or create connection state
423    let state = connections.entry(socket_id).or_insert_with(|| {
424        let local = SocketAddr::v4(super::Ipv4Address::UNSPECIFIED, 0);
425        TcpSocketState {
426            connection: TcpConnection::new(local, remote),
427            send_buffer: Vec::new(),
428            recv_buffer: Vec::new(),
429            send_seq: generate_initial_seq(),
430            recv_seq: 0,
431        }
432    });
433
434    // Update connection state
435    state.connection.remote = remote;
436    if state.connection.state == TcpState::Closed {
437        state.connection.state = TcpState::Established;
438    }
439
440    // Buffer data for transmission
441    state.send_buffer.extend_from_slice(data);
442
443    // In a real implementation, this would:
444    // 1. Segment data into MSS-sized chunks
445    // 2. Create TCP headers with proper seq/ack numbers
446    // 3. Pass to IP layer for transmission
447    // 4. Start retransmission timer
448
449    // For now, simulate immediate transmission
450    let bytes_sent = data.len();
451    state.send_seq = state.send_seq.wrapping_add(bytes_sent as u32);
452    state.send_buffer.clear();
453
454    #[cfg(feature = "net_debug")]
455    println!(
456        "[TCP] Transmitted {} bytes to {:?} (socket {})",
457        bytes_sent, remote, socket_id
458    );
459}
460
461/// Receive data from TCP connection
462pub fn receive_data(socket_id: usize, buffer: &mut Vec<u8>) -> usize {
463    let mut connections = TCP_CONNECTIONS.lock();
464
465    if let Some(state) = connections.get_mut(&socket_id) {
466        if state.connection.state != TcpState::Established {
467            return 0;
468        }
469
470        // Copy data from receive buffer
471        let bytes_available = state.recv_buffer.len();
472        if bytes_available > 0 {
473            buffer.extend_from_slice(&state.recv_buffer);
474            state.recv_buffer.clear();
475            state.recv_seq = state.recv_seq.wrapping_add(bytes_available as u32);
476
477            #[cfg(feature = "net_debug")]
478            println!(
479                "[TCP] Received {} bytes from socket {}",
480                bytes_available, socket_id
481            );
482
483            return bytes_available;
484        }
485    }
486
487    0
488}
489
490/// Close a TCP connection
491pub fn close_connection(socket_id: usize) {
492    let mut connections = TCP_CONNECTIONS.lock();
493
494    if let Some(state) = connections.get_mut(&socket_id) {
495        // Initiate TCP close sequence
496        match state.connection.state {
497            TcpState::Established => {
498                // Send FIN, transition to FIN_WAIT_1
499                state.connection.state = TcpState::FinWait1;
500
501                // In real implementation: send FIN packet and wait for ACK
502                // For simulation, immediately transition through close sequence
503                state.connection.state = TcpState::Closed;
504            }
505            TcpState::CloseWait => {
506                // Send FIN, transition to LAST_ACK
507                state.connection.state = TcpState::LastAck;
508                state.connection.state = TcpState::Closed;
509            }
510            _ => {
511                // Force close
512                state.connection.state = TcpState::Closed;
513            }
514        }
515
516        // Clear buffers
517        state.send_buffer.clear();
518        state.recv_buffer.clear();
519    }
520
521    // Remove from connection table
522    connections.remove(&socket_id);
523
524    #[cfg(feature = "net_debug")]
525    println!("[TCP] Closed connection for socket {}", socket_id);
526}
527
528/// Process incoming TCP packet (called by IP layer).
529///
530/// Parses the TCP header, finds the matching connection in the
531/// connection table, and dispatches to the state machine.
532pub fn process_packet(
533    src_addr: super::IpAddress,
534    _dst_addr: super::IpAddress,
535    data: &[u8],
536) -> Result<(), KernelError> {
537    if data.len() < TCP_HEADER_SIZE {
538        return Err(KernelError::InvalidArgument {
539            name: "tcp_packet",
540            value: "too_short",
541        });
542    }
543
544    // Parse TCP header
545    let src_port = u16::from_be_bytes([data[0], data[1]]);
546    let dst_port = u16::from_be_bytes([data[2], data[3]]);
547    let seq_num = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
548    let ack_num = u32::from_be_bytes([data[8], data[9], data[10], data[11]]);
549    let data_offset = ((data[12] >> 4) * 4) as usize;
550    let flags = TcpFlags::new(data[13]);
551    let _window = u16::from_be_bytes([data[14], data[15]]);
552
553    // Extract payload
554    let payload = if data.len() > data_offset {
555        &data[data_offset..]
556    } else {
557        &[]
558    };
559
560    let mut connections = TCP_CONNECTIONS.lock();
561    let remote = SocketAddr::new(src_addr, src_port);
562
563    // Find socket by remote address match or listening on dst port
564    for (_socket_id, state) in connections.iter_mut() {
565        if state.connection.remote == remote
566            || (state.connection.state == TcpState::Listen
567                && state.connection.local.port() == dst_port)
568        {
569            // Handle new connections on listening sockets
570            if flags.has(TcpFlags::SYN)
571                && !flags.has(TcpFlags::ACK)
572                && state.connection.state == TcpState::Listen
573            {
574                let local_addr = state.connection.local;
575                if let Err(_e) =
576                    super::socket::queue_pending_connection(local_addr, remote, seq_num)
577                {
578                    #[cfg(feature = "net_debug")]
579                    println!("[TCP] Failed to queue connection: {:?}", _e);
580                }
581                return Ok(());
582            }
583
584            // Dispatch to the state machine for all other transitions
585            process_tcp_state_transition(
586                state, flags, seq_num, ack_num, payload, src_addr, src_port,
587            );
588
589            return Ok(());
590        }
591    }
592
593    // No matching connection -- send RST if the incoming packet is not RST
594    if !flags.has(TcpFlags::RST) {
595        let rst = build_tcp_segment(
596            dst_port,
597            src_port,
598            ack_num,
599            seq_num.wrapping_add(payload.len() as u32),
600            TcpFlags::RST | TcpFlags::ACK,
601            0,
602            &[],
603        );
604        let _ = send_tcp_via_ip(src_addr, &rst);
605    }
606
607    Ok(())
608}
609
/// Produce the initial sequence number for a new connection.
///
/// Values come from a process-wide counter that starts at 1000000 and goes
/// up by one per call (wrapping on overflow), so they are predictable; a
/// real implementation would mix in secure randomness and a timestamp.
fn generate_initial_seq() -> u32 {
    use core::sync::atomic::{AtomicU32, Ordering};

    static NEXT_ISN: AtomicU32 = AtomicU32::new(1000000);

    // `fetch_add` hands back the value from before the increment.
    NEXT_ISN.fetch_add(1, Ordering::Relaxed)
}
617
618/// Get connection statistics
619pub fn get_stats() -> TcpStats {
620    let connections = TCP_CONNECTIONS.lock();
621    TcpStats {
622        active_connections: connections.len(),
623        total_bytes_sent: 0, // Would track in real implementation
624        total_bytes_recv: 0, // Would track in real implementation
625        retransmissions: 0,  // Would track in real implementation
626    }
627}
628
/// TCP statistics.
///
/// As filled in by `get_stats`, only `active_connections` carries a measured
/// value; the other counters are always 0 for now.
#[derive(Debug, Clone, Copy, Default)]
pub struct TcpStats {
    /// Number of entries in the global connection table.
    pub active_connections: usize,
    /// Total payload bytes sent (not tracked yet).
    pub total_bytes_sent: u64,
    /// Total payload bytes received (not tracked yet).
    pub total_bytes_recv: u64,
    /// Number of retransmitted segments (not tracked yet).
    pub retransmissions: u64,
}
637
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::Ipv4Address;

    /// Setting one flag makes it visible without making another flag appear.
    #[test]
    fn test_tcp_flags() {
        let mut syn_only = TcpFlags::new(0);
        syn_only.set(TcpFlags::SYN);

        let syn_seen = syn_only.has(TcpFlags::SYN);
        let ack_seen = syn_only.has(TcpFlags::ACK);

        assert!(syn_seen);
        assert!(!ack_seen);
    }

    /// A freshly constructed connection starts out closed.
    #[test]
    fn test_tcp_connection() {
        let conn = TcpConnection::new(
            SocketAddr::v4(Ipv4Address::LOCALHOST, 8080),
            SocketAddr::v4(Ipv4Address::new(192, 168, 1, 1), 80),
        );

        assert_eq!(conn.state, TcpState::Closed);
    }
}