1use alloc::vec::Vec;
7
8use spin::Mutex;
9
10use super::{IpAddress, Ipv4Address};
11use crate::error::KernelError;
12
/// Transport protocols carried over IP by this stack.
///
/// The discriminants are the protocol numbers written into the IPv4 header's
/// `protocol` byte, so `IpProtocol::Tcp as u8` can be stored there directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum IpProtocol {
    /// Internet Control Message Protocol (protocol number 1).
    Icmp = 0x01,
    /// Transmission Control Protocol (protocol number 6).
    Tcp = 0x06,
    /// User Datagram Protocol (protocol number 17).
    Udp = 0x11,
}
21
22#[derive(Debug, Clone)]
24pub struct Ipv4Header {
25 pub version: u8,
26 pub ihl: u8,
27 pub tos: u8,
28 pub total_length: u16,
29 pub identification: u16,
30 pub flags: u8,
31 pub fragment_offset: u16,
32 pub ttl: u8,
33 pub protocol: u8,
34 pub checksum: u16,
35 pub source: Ipv4Address,
36 pub destination: Ipv4Address,
37}
38
39impl Ipv4Header {
40 pub const MIN_SIZE: usize = 20;
41
42 pub fn new(src: Ipv4Address, dst: Ipv4Address, protocol: IpProtocol) -> Self {
43 Self {
44 version: 4,
45 ihl: 5, tos: 0,
47 total_length: 0,
48 identification: 0,
49 flags: 0,
50 fragment_offset: 0,
51 ttl: 64,
52 protocol: protocol as u8,
53 checksum: 0,
54 source: src,
55 destination: dst,
56 }
57 }
58
59 pub fn to_bytes(&self) -> [u8; 20] {
60 let mut bytes = [0u8; 20];
61
62 bytes[0] = (self.version << 4) | self.ihl;
63 bytes[1] = self.tos;
64 bytes[2..4].copy_from_slice(&self.total_length.to_be_bytes());
65 bytes[4..6].copy_from_slice(&self.identification.to_be_bytes());
66 bytes[6] = (self.flags << 5) | ((self.fragment_offset >> 8) as u8);
67 bytes[7] = (self.fragment_offset & 0xFF) as u8;
68 bytes[8] = self.ttl;
69 bytes[9] = self.protocol;
70 bytes[10..12].copy_from_slice(&self.checksum.to_be_bytes());
71 bytes[12..16].copy_from_slice(&self.source.0);
72 bytes[16..20].copy_from_slice(&self.destination.0);
73
74 bytes
75 }
76
77 pub fn from_bytes(bytes: &[u8]) -> Result<Self, KernelError> {
78 if bytes.len() < Self::MIN_SIZE {
79 return Err(KernelError::InvalidArgument {
80 name: "ip_header",
81 value: "too_short",
82 });
83 }
84
85 let version = bytes[0] >> 4;
86 if version != 4 {
87 return Err(KernelError::InvalidArgument {
88 name: "ip_version",
89 value: "not_ipv4",
90 });
91 }
92
93 Ok(Self {
94 version,
95 ihl: bytes[0] & 0x0F,
96 tos: bytes[1],
97 total_length: u16::from_be_bytes([bytes[2], bytes[3]]),
98 identification: u16::from_be_bytes([bytes[4], bytes[5]]),
99 flags: bytes[6] >> 5,
100 fragment_offset: u16::from_be_bytes([bytes[6] & 0x1F, bytes[7]]),
101 ttl: bytes[8],
102 protocol: bytes[9],
103 checksum: u16::from_be_bytes([bytes[10], bytes[11]]),
104 source: Ipv4Address([bytes[12], bytes[13], bytes[14], bytes[15]]),
105 destination: Ipv4Address([bytes[16], bytes[17], bytes[18], bytes[19]]),
106 })
107 }
108
109 pub fn calculate_checksum(&mut self) {
111 self.checksum = 0;
112 let bytes = self.to_bytes();
113
114 let mut sum: u32 = 0;
115 for i in 0..10 {
116 sum += u16::from_be_bytes([bytes[i * 2], bytes[i * 2 + 1]]) as u32;
117 }
118
119 while sum >> 16 != 0 {
120 sum = (sum & 0xFFFF) + (sum >> 16);
121 }
122
123 self.checksum = !(sum as u16);
124 }
125}
126
127#[derive(Debug, Clone)]
129pub struct RouteEntry {
130 pub destination: Ipv4Address,
131 pub netmask: Ipv4Address,
132 pub gateway: Option<Ipv4Address>,
133 pub interface: usize,
134}
135
136#[allow(dead_code)] #[derive(Debug, Clone, Copy)]
139pub struct InterfaceConfig {
140 pub ip_addr: Ipv4Address,
142 pub subnet_mask: Ipv4Address,
144 pub gateway: Option<Ipv4Address>,
146}
147
148static INTERFACE_CONFIG: Mutex<InterfaceConfig> = Mutex::new(InterfaceConfig {
150 ip_addr: Ipv4Address::ANY,
151 subnet_mask: Ipv4Address::ANY,
152 gateway: None,
153});
154
155pub fn get_interface_ip() -> Ipv4Address {
157 INTERFACE_CONFIG.lock().ip_addr
158}
159
160pub fn get_interface_config() -> InterfaceConfig {
162 *INTERFACE_CONFIG.lock()
163}
164
165pub fn set_interface_config(ip: Ipv4Address, mask: Ipv4Address, gw: Option<Ipv4Address>) {
167 let mut config = INTERFACE_CONFIG.lock();
168 config.ip_addr = ip;
169 config.subnet_mask = mask;
170 config.gateway = gw;
171
172 println!(
173 "[IP] Interface configured: {}.{}.{}.{}/{}.{}.{}.{}",
174 ip.0[0], ip.0[1], ip.0[2], ip.0[3], mask.0[0], mask.0[1], mask.0[2], mask.0[3],
175 );
176
177 if let Some(gateway) = gw {
178 println!(
179 "[IP] Gateway: {}.{}.{}.{}",
180 gateway.0[0], gateway.0[1], gateway.0[2], gateway.0[3],
181 );
182 }
183}
184
185static ROUTES: Mutex<Vec<RouteEntry>> = Mutex::new(Vec::new());
187
188pub fn add_route(entry: RouteEntry) {
190 ROUTES.lock().push(entry);
191}
192
193pub fn lookup_route(dest: Ipv4Address) -> Option<RouteEntry> {
195 let routes = ROUTES.lock();
196 for route in routes.iter() {
197 let dest_masked = dest.to_u32() & route.netmask.to_u32();
198 let route_masked = route.destination.to_u32() & route.netmask.to_u32();
199
200 if dest_masked == route_masked {
201 return Some(route.clone());
202 }
203 }
204 None
205}
206
207pub fn get_routes() -> Vec<RouteEntry> {
209 ROUTES.lock().clone()
210}
211
/// Source of IPv4 `identification` values. Starts at 1; `send` takes one
/// value per IPv4 packet via `fetch_add`, which wraps around on overflow.
static IP_ID_COUNTER: core::sync::atomic::AtomicU16 =
    core::sync::atomic::AtomicU16::new(0x0001);
214
215pub fn send(dest: IpAddress, protocol: IpProtocol, data: &[u8]) -> Result<(), KernelError> {
220 match dest {
221 IpAddress::V4(dest_v4) => {
222 let src = get_interface_ip();
224
225 let mut header = Ipv4Header::new(src, dest_v4, protocol);
226 header.total_length = (Ipv4Header::MIN_SIZE + data.len()) as u16;
227 header.identification =
228 IP_ID_COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
229 header.flags = 0x02; header.calculate_checksum();
231
232 let header_bytes = header.to_bytes();
234 let mut ip_packet = Vec::with_capacity(header_bytes.len() + data.len());
235 ip_packet.extend_from_slice(&header_bytes);
236 ip_packet.extend_from_slice(data);
237
238 let dst_mac = if dest_v4 == Ipv4Address::BROADCAST {
240 super::MacAddress::BROADCAST
241 } else {
242 super::arp::resolve(dest_v4).unwrap_or_else(|| {
244 super::arp::send_arp_request(dest_v4);
245 super::MacAddress::BROADCAST
246 })
247 };
248
249 let src_mac = super::device::with_device("eth0", |dev| dev.mac_address())
250 .unwrap_or(super::MacAddress::ZERO);
251
252 let frame = super::ethernet::construct_frame(
254 dst_mac,
255 src_mac,
256 super::ethernet::ETHERTYPE_IPV4,
257 &ip_packet,
258 );
259
260 let pkt = super::Packet::from_bytes(&frame);
262 super::device::with_device_mut("eth0", |dev| {
263 let _ = dev.transmit(&pkt);
264 });
265
266 super::update_stats_tx(header.total_length as usize);
267
268 Ok(())
269 }
270 IpAddress::V6(dest_v6) => {
271 let src = super::ipv6::select_source_address(&dest_v6)
273 .unwrap_or(super::Ipv6Address::UNSPECIFIED);
274 let next_header = match protocol {
275 IpProtocol::Tcp => super::ipv6::NEXT_HEADER_TCP,
276 IpProtocol::Udp => super::ipv6::NEXT_HEADER_UDP,
277 IpProtocol::Icmp => super::ipv6::NEXT_HEADER_ICMPV6,
278 };
279 super::ipv6::send(&src, &dest_v6, next_header, data)
280 }
281 }
282}
283
284pub fn init() -> Result<(), KernelError> {
286 println!("[IP] Initializing IP layer...");
287
288 add_route(RouteEntry {
290 destination: Ipv4Address::new(127, 0, 0, 0),
291 netmask: Ipv4Address::new(255, 0, 0, 0),
292 gateway: None,
293 interface: 0,
294 });
295
296 println!("[IP] IP layer initialized");
297 Ok(())
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// One's-complement sum of the 16-bit big-endian words in `bytes`, folded
    /// to 16 bits. A header carrying a valid checksum sums to 0xFFFF.
    fn ones_complement_sum(bytes: &[u8]) -> u32 {
        let mut sum: u32 = bytes
            .chunks_exact(2)
            .map(|word| u32::from(u16::from_be_bytes([word[0], word[1]])))
            .sum();
        while sum >> 16 != 0 {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }
        sum
    }

    #[test]
    fn test_ipv4_header() {
        let src = Ipv4Address::new(192, 168, 1, 1);
        let dst = Ipv4Address::new(192, 168, 1, 2);
        let header = Ipv4Header::new(src, dst, IpProtocol::Tcp);

        assert_eq!(header.version, 4);
        assert_eq!(header.ihl, 5);
        assert_eq!(header.ttl, 64);
        assert_eq!(header.protocol, 6);
        assert_eq!(header.source, src);
        assert_eq!(header.destination, dst);
    }

    #[test]
    fn test_ipv4_header_roundtrip() {
        let src = Ipv4Address::new(10, 0, 0, 1);
        let dst = Ipv4Address::new(10, 0, 0, 2);
        let mut header = Ipv4Header::new(src, dst, IpProtocol::Udp);
        header.calculate_checksum();

        let bytes = header.to_bytes();
        let parsed = Ipv4Header::from_bytes(&bytes).unwrap();

        assert_eq!(parsed.source, src);
        assert_eq!(parsed.destination, dst);
        assert_eq!(parsed.protocol, 17);
        assert_eq!(parsed.checksum, header.checksum);
        // The serialized header, checksum included, must verify.
        assert_eq!(ones_complement_sum(&bytes), 0xFFFF);
    }

    #[test]
    fn test_ipv4_checksum_known_vector() {
        // 4500 0073 0000 4000 4011 xxxx c0a8 0001 c0a8 00c7 -> checksum 0xB861.
        let mut header = Ipv4Header::new(
            Ipv4Address::new(192, 168, 0, 1),
            Ipv4Address::new(192, 168, 0, 199),
            IpProtocol::Udp,
        );
        header.total_length = 0x0073;
        header.flags = 0x02;
        header.calculate_checksum();

        assert_eq!(header.checksum, 0xB861);
    }

    #[test]
    fn test_ipv4_flags_and_fragment_offset_roundtrip() {
        let mut header = Ipv4Header::new(
            Ipv4Address::new(10, 0, 0, 1),
            Ipv4Address::new(10, 0, 0, 2),
            IpProtocol::Icmp,
        );
        header.flags = 0x02;
        header.fragment_offset = 0x1234;

        let bytes = header.to_bytes();
        assert_eq!(bytes[6], 0x52);
        assert_eq!(bytes[7], 0x34);

        let parsed = Ipv4Header::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.flags, 0x02);
        assert_eq!(parsed.fragment_offset, 0x1234);
    }

    #[test]
    fn test_ipv4_from_bytes_rejects_invalid() {
        assert!(Ipv4Header::from_bytes(&[]).is_err());
        assert!(Ipv4Header::from_bytes(&[0u8; 19]).is_err());

        let mut bytes = Ipv4Header::new(
            Ipv4Address::new(10, 0, 0, 1),
            Ipv4Address::new(10, 0, 0, 2),
            IpProtocol::Tcp,
        )
        .to_bytes();
        // Version nibble 6 must be rejected by the IPv4 parser.
        bytes[0] = 0x65;
        assert!(Ipv4Header::from_bytes(&bytes).is_err());
    }
}