⚠️ VeridianOS Kernel Documentation — This is low-level (`no_std`) kernel code. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/services/lb/l4.rs

1//! L4 (Transport Layer) Load Balancer
2//!
3//! Provides TCP/UDP load balancing with multiple algorithms including
4//! round-robin, least connections, weighted round-robin, random, and
5//! IP hash.
6
7#![allow(dead_code)]
8
9use alloc::{collections::BTreeMap, string::String, vec::Vec};
10
11// ---------------------------------------------------------------------------
12// Backend
13// ---------------------------------------------------------------------------
14
/// One upstream server that a virtual IP can forward traffic to.
///
/// Newly constructed backends are healthy and have zeroed counters
/// (see [`Backend::new`]).
#[derive(Debug, Clone)]
pub struct Backend {
    /// IP address of the backend, kept as a string.
    pub address: String,
    /// Port the backend listens on.
    pub port: u16,
    /// Relative weight for weighted algorithms (intended range 1-100; not enforced here).
    pub weight: u32,
    /// `true` while the backend may receive new traffic.
    pub healthy: bool,
    /// Connections currently attributed to this backend.
    pub active_connections: u32,
    /// Cumulative number of requests routed to this backend.
    pub total_requests: u64,
}

impl Backend {
    /// Build a backend at `address:port` with the given `weight`.
    ///
    /// The result is marked healthy and both counters start at zero.
    pub fn new(address: String, port: u16, weight: u32) -> Self {
        let (active_connections, total_requests) = (0, 0);
        Self {
            healthy: true,
            active_connections,
            total_requests,
            weight,
            port,
            address,
        }
    }
}
45
46// ---------------------------------------------------------------------------
47// Load Balancing Algorithm
48// ---------------------------------------------------------------------------
49
/// Strategy a virtual IP uses to choose one of its healthy backends.
///
/// `RoundRobin` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LbAlgorithm {
    /// Cycle through the healthy backends in order.
    #[default]
    RoundRobin,
    /// Choose the healthy backend with the fewest active connections.
    LeastConnections,
    /// Rotate through the healthy backends in proportion to their weights.
    WeightedRoundRobin,
    /// Choose a healthy backend pseudo-randomly.
    Random,
    /// Choose a healthy backend from a hash of the client IP address.
    IpHash,
}
65
66// ---------------------------------------------------------------------------
67// Virtual IP
68// ---------------------------------------------------------------------------
69
70/// A virtual IP (VIP) with its backend pool.
71#[derive(Debug, Clone)]
72pub struct VirtualIp {
73    /// VIP address.
74    pub vip_addr: String,
75    /// VIP port.
76    pub vip_port: u16,
77    /// Backend servers.
78    pub backends: Vec<Backend>,
79    /// Load balancing algorithm.
80    pub algorithm: LbAlgorithm,
81}
82
83impl VirtualIp {
84    /// Get the number of healthy backends.
85    pub fn healthy_count(&self) -> usize {
86        self.backends.iter().filter(|b| b.healthy).count()
87    }
88}
89
90// ---------------------------------------------------------------------------
91// L4 Error
92// ---------------------------------------------------------------------------
93
/// Errors returned by L4 load balancer operations.
///
/// Implements [`core::fmt::Display`] so errors can be logged or reported
/// without relying on the `Debug` representation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum L4Error {
    /// No VIP is registered under this "addr:port" key.
    VipNotFound(String),
    /// A VIP is already registered under this "addr:port" key.
    VipAlreadyExists(String),
    /// The VIP exists but none of its backends is marked healthy.
    NoHealthyBackend,
    /// No backend in the VIP's pool has this address.
    BackendNotFound(String),
}

impl core::fmt::Display for L4Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            L4Error::VipNotFound(key) => write!(f, "virtual IP {} not found", key),
            L4Error::VipAlreadyExists(key) => write!(f, "virtual IP {} already exists", key),
            L4Error::NoHealthyBackend => f.write_str("no healthy backend available"),
            L4Error::BackendNotFound(addr) => write!(f, "backend {} not found", addr),
        }
    }
}
106
107// ---------------------------------------------------------------------------
108// L4 Load Balancer
109// ---------------------------------------------------------------------------
110
111/// L4 Load Balancer implementation.
112#[derive(Debug)]
113pub struct L4LoadBalancer {
114    /// Virtual IPs keyed by "addr:port".
115    vips: BTreeMap<String, VirtualIp>,
116    /// Round-robin counter (per VIP is cleaner but global works for
117    /// simplicity).
118    rr_counter: u64,
119    /// Pseudo-random state.
120    random_state: u64,
121}
122
123impl Default for L4LoadBalancer {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl L4LoadBalancer {
130    /// Create a new L4 load balancer.
131    pub fn new() -> Self {
132        L4LoadBalancer {
133            vips: BTreeMap::new(),
134            rr_counter: 0,
135            random_state: 0x12345678,
136        }
137    }
138
139    /// VIP key string.
140    fn vip_key(addr: &str, port: u16) -> String {
141        alloc::format!("{}:{}", addr, port)
142    }
143
144    /// Add a virtual IP.
145    pub fn add_vip(
146        &mut self,
147        vip_addr: String,
148        vip_port: u16,
149        algorithm: LbAlgorithm,
150    ) -> Result<(), L4Error> {
151        let key = Self::vip_key(&vip_addr, vip_port);
152        if self.vips.contains_key(&key) {
153            return Err(L4Error::VipAlreadyExists(key));
154        }
155        self.vips.insert(
156            key,
157            VirtualIp {
158                vip_addr,
159                vip_port,
160                backends: Vec::new(),
161                algorithm,
162            },
163        );
164        Ok(())
165    }
166
167    /// Remove a virtual IP.
168    pub fn remove_vip(&mut self, vip_addr: &str, vip_port: u16) -> Result<(), L4Error> {
169        let key = Self::vip_key(vip_addr, vip_port);
170        self.vips
171            .remove(&key)
172            .map(|_| ())
173            .ok_or(L4Error::VipNotFound(key))
174    }
175
176    /// Add a backend to a VIP.
177    pub fn add_backend(
178        &mut self,
179        vip_addr: &str,
180        vip_port: u16,
181        backend: Backend,
182    ) -> Result<(), L4Error> {
183        let key = Self::vip_key(vip_addr, vip_port);
184        let vip = self.vips.get_mut(&key).ok_or(L4Error::VipNotFound(key))?;
185        vip.backends.push(backend);
186        Ok(())
187    }
188
189    /// Remove a backend from a VIP by address.
190    pub fn remove_backend(
191        &mut self,
192        vip_addr: &str,
193        vip_port: u16,
194        backend_addr: &str,
195    ) -> Result<(), L4Error> {
196        let key = Self::vip_key(vip_addr, vip_port);
197        let vip = self
198            .vips
199            .get_mut(&key)
200            .ok_or_else(|| L4Error::VipNotFound(key.clone()))?;
201
202        let before = vip.backends.len();
203        vip.backends.retain(|b| b.address != backend_addr);
204        if vip.backends.len() == before {
205            return Err(L4Error::BackendNotFound(String::from(backend_addr)));
206        }
207        Ok(())
208    }
209
210    /// Select a backend using the VIP's configured algorithm.
211    pub fn select_backend(
212        &mut self,
213        vip_addr: &str,
214        vip_port: u16,
215        client_ip: u32,
216    ) -> Result<(String, u16), L4Error> {
217        let key = Self::vip_key(vip_addr, vip_port);
218        let vip = self
219            .vips
220            .get_mut(&key)
221            .ok_or_else(|| L4Error::VipNotFound(key.clone()))?;
222
223        let healthy: Vec<usize> = vip
224            .backends
225            .iter()
226            .enumerate()
227            .filter(|(_, b)| b.healthy)
228            .map(|(i, _)| i)
229            .collect();
230
231        if healthy.is_empty() {
232            return Err(L4Error::NoHealthyBackend);
233        }
234
235        let idx = match vip.algorithm {
236            LbAlgorithm::RoundRobin => {
237                let i = (self.rr_counter as usize) % healthy.len();
238                self.rr_counter += 1;
239                healthy[i]
240            }
241            LbAlgorithm::LeastConnections => {
242                let mut min_idx = healthy[0];
243                let mut min_conns = vip.backends[healthy[0]].active_connections;
244                for &h in &healthy[1..] {
245                    if vip.backends[h].active_connections < min_conns {
246                        min_conns = vip.backends[h].active_connections;
247                        min_idx = h;
248                    }
249                }
250                min_idx
251            }
252            LbAlgorithm::WeightedRoundRobin => {
253                // Weighted selection: pick highest weight among healthy
254                let total_weight: u32 = healthy.iter().map(|&i| vip.backends[i].weight).sum();
255                if total_weight == 0 {
256                    healthy[0]
257                } else {
258                    let target = (self.rr_counter % total_weight as u64) as u32;
259                    self.rr_counter += 1;
260                    let mut cumulative = 0u32;
261                    let mut selected = healthy[0];
262                    for &h in &healthy {
263                        cumulative += vip.backends[h].weight;
264                        if target < cumulative {
265                            selected = h;
266                            break;
267                        }
268                    }
269                    selected
270                }
271            }
272            LbAlgorithm::Random => {
273                // LCG pseudo-random
274                self.random_state = self
275                    .random_state
276                    .wrapping_mul(6364136223846793005)
277                    .wrapping_add(1442695040888963407);
278                let i = (self.random_state >> 33) as usize % healthy.len();
279                healthy[i]
280            }
281            LbAlgorithm::IpHash => {
282                let hash = client_ip.wrapping_mul(2654435761);
283                let i = (hash as usize) % healthy.len();
284                healthy[i]
285            }
286        };
287
288        vip.backends[idx].active_connections += 1;
289        vip.backends[idx].total_requests += 1;
290
291        Ok((vip.backends[idx].address.clone(), vip.backends[idx].port))
292    }
293
294    /// Run health checks on all backends in all VIPs.
295    pub fn health_check(&mut self) {
296        // Simulated: just mark backends based on active connections
297        // In real code this would send probes.
298        for vip in self.vips.values_mut() {
299            for backend in &mut vip.backends {
300                // If too many connections, consider unhealthy
301                if backend.active_connections > 10000 {
302                    backend.healthy = false;
303                }
304            }
305        }
306    }
307
308    /// Mark a specific backend as healthy/unhealthy.
309    pub fn set_backend_health(
310        &mut self,
311        vip_addr: &str,
312        vip_port: u16,
313        backend_addr: &str,
314        healthy: bool,
315    ) -> Result<(), L4Error> {
316        let key = Self::vip_key(vip_addr, vip_port);
317        let vip = self
318            .vips
319            .get_mut(&key)
320            .ok_or_else(|| L4Error::VipNotFound(key.clone()))?;
321
322        for backend in &mut vip.backends {
323            if backend.address == backend_addr {
324                backend.healthy = healthy;
325                return Ok(());
326            }
327        }
328        Err(L4Error::BackendNotFound(String::from(backend_addr)))
329    }
330
331    /// Get VIP info.
332    pub fn get_vip(&self, vip_addr: &str, vip_port: u16) -> Option<&VirtualIp> {
333        let key = Self::vip_key(vip_addr, vip_port);
334        self.vips.get(&key)
335    }
336
337    /// List all VIPs.
338    pub fn list_vips(&self) -> Vec<&VirtualIp> {
339        self.vips.values().collect()
340    }
341
342    /// Get total number of VIPs.
343    pub fn vip_count(&self) -> usize {
344        self.vips.len()
345    }
346}
347
348// ---------------------------------------------------------------------------
349// Tests
350// ---------------------------------------------------------------------------
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Address of the VIP used throughout these tests.
    const VIP: &str = "10.96.0.1";
    /// Port of the VIP used throughout these tests.
    const VIP_PORT: u16 = 80;

    /// Build a balancer with one VIP (`VIP:VIP_PORT`) using `algorithm`, plus
    /// one backend on port 8080 for every `(address, weight)` pair.
    fn lb_with(algorithm: LbAlgorithm, backends: &[(&str, u32)]) -> L4LoadBalancer {
        let mut lb = L4LoadBalancer::new();
        lb.add_vip(String::from(VIP), VIP_PORT, algorithm).unwrap();
        for &(addr, weight) in backends {
            lb.add_backend(VIP, VIP_PORT, Backend::new(String::from(addr), 8080, weight))
                .unwrap();
        }
        lb
    }

    /// Round-robin VIP with two equally weighted backends, 10.0.0.1 and 10.0.0.2.
    fn make_lb() -> L4LoadBalancer {
        lb_with(LbAlgorithm::RoundRobin, &[("10.0.0.1", 1), ("10.0.0.2", 1)])
    }

    /// Select a backend on the test VIP and return only its address.
    fn pick(lb: &mut L4LoadBalancer, client_ip: u32) -> String {
        lb.select_backend(VIP, VIP_PORT, client_ip).unwrap().0
    }

    #[test]
    fn test_add_vip() {
        let lb = lb_with(LbAlgorithm::RoundRobin, &[]);
        assert_eq!(lb.vip_count(), 1);
    }

    #[test]
    fn test_add_duplicate_vip() {
        let mut lb = lb_with(LbAlgorithm::RoundRobin, &[]);
        let second = lb.add_vip(String::from(VIP), VIP_PORT, LbAlgorithm::RoundRobin);
        assert!(second.is_err());
    }

    #[test]
    fn test_remove_vip() {
        let mut lb = make_lb();
        lb.remove_vip(VIP, VIP_PORT).unwrap();
        assert_eq!(lb.vip_count(), 0);
    }

    #[test]
    fn test_round_robin() {
        let mut lb = make_lb();
        let first = pick(&mut lb, 0);
        let second = pick(&mut lb, 0);
        assert_ne!(first, second);
    }

    #[test]
    fn test_least_connections() {
        let mut lb = lb_with(LbAlgorithm::LeastConnections, &[("10.0.0.1", 1)]);
        let mut busy = Backend::new(String::from("10.0.0.2"), 8080, 1);
        busy.active_connections = 5;
        lb.add_backend(VIP, VIP_PORT, busy).unwrap();

        // The idle backend has fewer connections, so it must win.
        assert_eq!(pick(&mut lb, 0), "10.0.0.1");
    }

    #[test]
    fn test_ip_hash_deterministic() {
        let mut lb = lb_with(LbAlgorithm::IpHash, &[("10.0.0.1", 1), ("10.0.0.2", 1)]);
        let client_ip: u32 = 0xC0A8_0001; // 192.168.0.1
        let first = pick(&mut lb, client_ip);
        let again = pick(&mut lb, client_ip);
        assert_eq!(first, again, "same client IP must map to the same backend");
    }

    #[test]
    fn test_no_healthy_backend() {
        let mut lb = make_lb();
        for addr in ["10.0.0.1", "10.0.0.2"] {
            lb.set_backend_health(VIP, VIP_PORT, addr, false).unwrap();
        }
        assert_eq!(
            lb.select_backend(VIP, VIP_PORT, 0),
            Err(L4Error::NoHealthyBackend)
        );
    }

    #[test]
    fn test_remove_backend() {
        let mut lb = make_lb();
        lb.remove_backend(VIP, VIP_PORT, "10.0.0.1").unwrap();
        let remaining = lb.get_vip(VIP, VIP_PORT).unwrap().backends.len();
        assert_eq!(remaining, 1);
    }

    #[test]
    fn test_weighted_round_robin() {
        let mut lb = lb_with(
            LbAlgorithm::WeightedRoundRobin,
            &[("10.0.0.1", 3), ("10.0.0.2", 1)],
        );
        // With weights 3:1, four selections should favour the first backend.
        let count_first = (0..4)
            .filter(|_| pick(&mut lb, 0) == "10.0.0.1")
            .count();
        assert!(count_first >= 2);
    }

    #[test]
    fn test_vip_not_found() {
        let mut lb = L4LoadBalancer::new();
        assert_eq!(
            lb.select_backend(VIP, VIP_PORT, 0),
            Err(L4Error::VipNotFound(String::from("10.96.0.1:80")))
        );
    }
}