1#![allow(dead_code)]
8
9use alloc::{collections::BTreeMap, string::String, vec::Vec};
10
/// One upstream server that a virtual IP can hand connections to.
#[derive(Debug, Clone)]
pub struct Backend {
    /// Backend address; `remove_backend` and `set_backend_health` locate a
    /// backend by this value.
    pub address: String,
    /// Port the backend serves on; returned alongside `address` on selection.
    pub port: u16,
    /// Relative share of picks under `LbAlgorithm::WeightedRoundRobin`.
    pub weight: u32,
    /// Only backends with this flag set are eligible for selection.
    pub healthy: bool,
    /// Connections attributed to this backend; bumped on every selection.
    pub active_connections: u32,
    /// Number of times the balancer has selected this backend.
    pub total_requests: u64,
}
31
32impl Backend {
33 pub fn new(address: String, port: u16, weight: u32) -> Self {
35 Backend {
36 address,
37 port,
38 weight,
39 healthy: true,
40 active_connections: 0,
41 total_requests: 0,
42 }
43 }
44}
45
/// Strategy a [`VirtualIp`] uses to choose among its healthy backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LbAlgorithm {
    /// Rotate through the healthy backends in order (the default).
    #[default]
    RoundRobin,
    /// Choose the healthy backend with the fewest active connections.
    LeastConnections,
    /// Rotate through the healthy backends, giving each a share of picks
    /// proportional to its `weight`.
    WeightedRoundRobin,
    /// Choose a pseudo-random healthy backend.
    Random,
    /// Derive the choice from a hash of the client IP, so a given client keeps
    /// landing on the same backend while the healthy set is unchanged.
    IpHash,
}
65
66#[derive(Debug, Clone)]
72pub struct VirtualIp {
73 pub vip_addr: String,
75 pub vip_port: u16,
77 pub backends: Vec<Backend>,
79 pub algorithm: LbAlgorithm,
81}
82
83impl VirtualIp {
84 pub fn healthy_count(&self) -> usize {
86 self.backends.iter().filter(|b| b.healthy).count()
87 }
88}
89
/// Errors returned by [`L4LoadBalancer`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum L4Error {
    /// No VIP is registered under the given `"addr:port"` key.
    VipNotFound(String),
    /// A VIP with the given `"addr:port"` key is already registered.
    VipAlreadyExists(String),
    /// The VIP exists but none of its backends is currently healthy.
    NoHealthyBackend,
    /// The VIP has no backend with the given address.
    BackendNotFound(String),
}

/// Human-readable messages, so the error can be logged or shown with `{}`.
impl core::fmt::Display for L4Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            L4Error::VipNotFound(key) => write!(f, "virtual IP {} not found", key),
            L4Error::VipAlreadyExists(key) => write!(f, "virtual IP {} already exists", key),
            L4Error::NoHealthyBackend => f.write_str("no healthy backend available"),
            L4Error::BackendNotFound(addr) => write!(f, "backend {} not found", addr),
        }
    }
}
106
107#[derive(Debug)]
113pub struct L4LoadBalancer {
114 vips: BTreeMap<String, VirtualIp>,
116 rr_counter: u64,
119 random_state: u64,
121}
122
123impl Default for L4LoadBalancer {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl L4LoadBalancer {
130 pub fn new() -> Self {
132 L4LoadBalancer {
133 vips: BTreeMap::new(),
134 rr_counter: 0,
135 random_state: 0x12345678,
136 }
137 }
138
139 fn vip_key(addr: &str, port: u16) -> String {
141 alloc::format!("{}:{}", addr, port)
142 }
143
144 pub fn add_vip(
146 &mut self,
147 vip_addr: String,
148 vip_port: u16,
149 algorithm: LbAlgorithm,
150 ) -> Result<(), L4Error> {
151 let key = Self::vip_key(&vip_addr, vip_port);
152 if self.vips.contains_key(&key) {
153 return Err(L4Error::VipAlreadyExists(key));
154 }
155 self.vips.insert(
156 key,
157 VirtualIp {
158 vip_addr,
159 vip_port,
160 backends: Vec::new(),
161 algorithm,
162 },
163 );
164 Ok(())
165 }
166
167 pub fn remove_vip(&mut self, vip_addr: &str, vip_port: u16) -> Result<(), L4Error> {
169 let key = Self::vip_key(vip_addr, vip_port);
170 self.vips
171 .remove(&key)
172 .map(|_| ())
173 .ok_or(L4Error::VipNotFound(key))
174 }
175
176 pub fn add_backend(
178 &mut self,
179 vip_addr: &str,
180 vip_port: u16,
181 backend: Backend,
182 ) -> Result<(), L4Error> {
183 let key = Self::vip_key(vip_addr, vip_port);
184 let vip = self.vips.get_mut(&key).ok_or(L4Error::VipNotFound(key))?;
185 vip.backends.push(backend);
186 Ok(())
187 }
188
189 pub fn remove_backend(
191 &mut self,
192 vip_addr: &str,
193 vip_port: u16,
194 backend_addr: &str,
195 ) -> Result<(), L4Error> {
196 let key = Self::vip_key(vip_addr, vip_port);
197 let vip = self
198 .vips
199 .get_mut(&key)
200 .ok_or_else(|| L4Error::VipNotFound(key.clone()))?;
201
202 let before = vip.backends.len();
203 vip.backends.retain(|b| b.address != backend_addr);
204 if vip.backends.len() == before {
205 return Err(L4Error::BackendNotFound(String::from(backend_addr)));
206 }
207 Ok(())
208 }
209
210 pub fn select_backend(
212 &mut self,
213 vip_addr: &str,
214 vip_port: u16,
215 client_ip: u32,
216 ) -> Result<(String, u16), L4Error> {
217 let key = Self::vip_key(vip_addr, vip_port);
218 let vip = self
219 .vips
220 .get_mut(&key)
221 .ok_or_else(|| L4Error::VipNotFound(key.clone()))?;
222
223 let healthy: Vec<usize> = vip
224 .backends
225 .iter()
226 .enumerate()
227 .filter(|(_, b)| b.healthy)
228 .map(|(i, _)| i)
229 .collect();
230
231 if healthy.is_empty() {
232 return Err(L4Error::NoHealthyBackend);
233 }
234
235 let idx = match vip.algorithm {
236 LbAlgorithm::RoundRobin => {
237 let i = (self.rr_counter as usize) % healthy.len();
238 self.rr_counter += 1;
239 healthy[i]
240 }
241 LbAlgorithm::LeastConnections => {
242 let mut min_idx = healthy[0];
243 let mut min_conns = vip.backends[healthy[0]].active_connections;
244 for &h in &healthy[1..] {
245 if vip.backends[h].active_connections < min_conns {
246 min_conns = vip.backends[h].active_connections;
247 min_idx = h;
248 }
249 }
250 min_idx
251 }
252 LbAlgorithm::WeightedRoundRobin => {
253 let total_weight: u32 = healthy.iter().map(|&i| vip.backends[i].weight).sum();
255 if total_weight == 0 {
256 healthy[0]
257 } else {
258 let target = (self.rr_counter % total_weight as u64) as u32;
259 self.rr_counter += 1;
260 let mut cumulative = 0u32;
261 let mut selected = healthy[0];
262 for &h in &healthy {
263 cumulative += vip.backends[h].weight;
264 if target < cumulative {
265 selected = h;
266 break;
267 }
268 }
269 selected
270 }
271 }
272 LbAlgorithm::Random => {
273 self.random_state = self
275 .random_state
276 .wrapping_mul(6364136223846793005)
277 .wrapping_add(1442695040888963407);
278 let i = (self.random_state >> 33) as usize % healthy.len();
279 healthy[i]
280 }
281 LbAlgorithm::IpHash => {
282 let hash = client_ip.wrapping_mul(2654435761);
283 let i = (hash as usize) % healthy.len();
284 healthy[i]
285 }
286 };
287
288 vip.backends[idx].active_connections += 1;
289 vip.backends[idx].total_requests += 1;
290
291 Ok((vip.backends[idx].address.clone(), vip.backends[idx].port))
292 }
293
294 pub fn health_check(&mut self) {
296 for vip in self.vips.values_mut() {
299 for backend in &mut vip.backends {
300 if backend.active_connections > 10000 {
302 backend.healthy = false;
303 }
304 }
305 }
306 }
307
308 pub fn set_backend_health(
310 &mut self,
311 vip_addr: &str,
312 vip_port: u16,
313 backend_addr: &str,
314 healthy: bool,
315 ) -> Result<(), L4Error> {
316 let key = Self::vip_key(vip_addr, vip_port);
317 let vip = self
318 .vips
319 .get_mut(&key)
320 .ok_or_else(|| L4Error::VipNotFound(key.clone()))?;
321
322 for backend in &mut vip.backends {
323 if backend.address == backend_addr {
324 backend.healthy = healthy;
325 return Ok(());
326 }
327 }
328 Err(L4Error::BackendNotFound(String::from(backend_addr)))
329 }
330
331 pub fn get_vip(&self, vip_addr: &str, vip_port: u16) -> Option<&VirtualIp> {
333 let key = Self::vip_key(vip_addr, vip_port);
334 self.vips.get(&key)
335 }
336
337 pub fn list_vips(&self) -> Vec<&VirtualIp> {
339 self.vips.values().collect()
340 }
341
342 pub fn vip_count(&self) -> usize {
344 self.vips.len()
345 }
346}
347
#[cfg(test)]
mod tests {
    // `ToString` is not in the no_std prelude; kept available for assertions
    // that build owned strings, hence the allow.
    #[allow(unused_imports)]
    use alloc::string::ToString;

    use super::*;

    /// Builds a balancer with VIP `10.96.0.1:80` using `algorithm`, backed by
    /// one backend on port 8080 per `(address, weight)` entry, in order.
    fn make_lb_with(algorithm: LbAlgorithm, backends: &[(&str, u32)]) -> L4LoadBalancer {
        let mut lb = L4LoadBalancer::new();
        lb.add_vip(String::from("10.96.0.1"), 80, algorithm).unwrap();
        for &(addr, weight) in backends {
            lb.add_backend("10.96.0.1", 80, Backend::new(String::from(addr), 8080, weight))
                .unwrap();
        }
        lb
    }

    /// Round-robin VIP `10.96.0.1:80` with two equally weighted backends.
    fn make_lb() -> L4LoadBalancer {
        make_lb_with(
            LbAlgorithm::RoundRobin,
            &[("10.0.0.1", 1), ("10.0.0.2", 1)],
        )
    }

    #[test]
    fn test_add_vip() {
        let mut lb = L4LoadBalancer::new();
        lb.add_vip(String::from("10.96.0.1"), 80, LbAlgorithm::RoundRobin)
            .unwrap();
        assert_eq!(lb.vip_count(), 1);
    }

    #[test]
    fn test_add_duplicate_vip() {
        let mut lb = L4LoadBalancer::new();
        lb.add_vip(String::from("10.96.0.1"), 80, LbAlgorithm::RoundRobin)
            .unwrap();
        assert_eq!(
            lb.add_vip(String::from("10.96.0.1"), 80, LbAlgorithm::RoundRobin),
            Err(L4Error::VipAlreadyExists(String::from("10.96.0.1:80")))
        );
        assert_eq!(lb.vip_count(), 1);
    }

    #[test]
    fn test_remove_vip() {
        let mut lb = make_lb();
        lb.remove_vip("10.96.0.1", 80).unwrap();
        assert_eq!(lb.vip_count(), 0);
    }

    #[test]
    fn test_remove_vip_not_found() {
        let mut lb = L4LoadBalancer::new();
        assert_eq!(
            lb.remove_vip("10.96.0.1", 80),
            Err(L4Error::VipNotFound(String::from("10.96.0.1:80")))
        );
    }

    #[test]
    fn test_round_robin() {
        let mut lb = make_lb();
        let (addr1, _) = lb.select_backend("10.96.0.1", 80, 0).unwrap();
        let (addr2, _) = lb.select_backend("10.96.0.1", 80, 0).unwrap();
        let (addr3, _) = lb.select_backend("10.96.0.1", 80, 0).unwrap();
        assert_ne!(addr1, addr2);
        // The rotation wraps back to the first backend.
        assert_eq!(addr1, addr3);
    }

    #[test]
    fn test_round_robin_skips_unhealthy() {
        let mut lb = make_lb();
        lb.set_backend_health("10.96.0.1", 80, "10.0.0.1", false)
            .unwrap();
        for _ in 0..3 {
            let (addr, _) = lb.select_backend("10.96.0.1", 80, 0).unwrap();
            assert_eq!(addr, "10.0.0.2");
        }
    }

    #[test]
    fn test_least_connections() {
        let mut lb = make_lb_with(LbAlgorithm::LeastConnections, &[("10.0.0.1", 1)]);
        let mut busy = Backend::new(String::from("10.0.0.2"), 8080, 1);
        busy.active_connections = 5;
        lb.add_backend("10.96.0.1", 80, busy).unwrap();

        let (addr, _) = lb.select_backend("10.96.0.1", 80, 0).unwrap();
        assert_eq!(addr, "10.0.0.1");
    }

    #[test]
    fn test_ip_hash_deterministic() {
        let mut lb = make_lb_with(LbAlgorithm::IpHash, &[("10.0.0.1", 1), ("10.0.0.2", 1)]);

        let client_ip: u32 = 0xC0A8_0001; // 192.168.0.1
        let (a1, _) = lb.select_backend("10.96.0.1", 80, client_ip).unwrap();
        let (a2, _) = lb.select_backend("10.96.0.1", 80, client_ip).unwrap();
        assert_eq!(a1, a2);
    }

    #[test]
    fn test_no_healthy_backend() {
        let mut lb = make_lb();
        lb.set_backend_health("10.96.0.1", 80, "10.0.0.1", false)
            .unwrap();
        lb.set_backend_health("10.96.0.1", 80, "10.0.0.2", false)
            .unwrap();
        assert_eq!(
            lb.select_backend("10.96.0.1", 80, 0),
            Err(L4Error::NoHealthyBackend)
        );
    }

    #[test]
    fn test_set_backend_health_unknown_backend() {
        let mut lb = make_lb();
        assert_eq!(
            lb.set_backend_health("10.96.0.1", 80, "10.0.0.9", false),
            Err(L4Error::BackendNotFound(String::from("10.0.0.9")))
        );
    }

    #[test]
    fn test_healthy_count() {
        let mut lb = make_lb();
        assert_eq!(lb.get_vip("10.96.0.1", 80).unwrap().healthy_count(), 2);
        lb.set_backend_health("10.96.0.1", 80, "10.0.0.2", false)
            .unwrap();
        assert_eq!(lb.get_vip("10.96.0.1", 80).unwrap().healthy_count(), 1);
    }

    #[test]
    fn test_remove_backend() {
        let mut lb = make_lb();
        lb.remove_backend("10.96.0.1", 80, "10.0.0.1").unwrap();
        let vip = lb.get_vip("10.96.0.1", 80).unwrap();
        assert_eq!(vip.backends.len(), 1);
        assert_eq!(vip.backends[0].address, "10.0.0.2");
    }

    #[test]
    fn test_remove_backend_not_found() {
        let mut lb = make_lb();
        assert_eq!(
            lb.remove_backend("10.96.0.1", 80, "10.0.0.9"),
            Err(L4Error::BackendNotFound(String::from("10.0.0.9")))
        );
        assert_eq!(lb.get_vip("10.96.0.1", 80).unwrap().backends.len(), 2);
    }

    #[test]
    fn test_add_backend_to_missing_vip() {
        let mut lb = L4LoadBalancer::new();
        assert_eq!(
            lb.add_backend(
                "10.96.0.1",
                80,
                Backend::new(String::from("10.0.0.1"), 8080, 1),
            ),
            Err(L4Error::VipNotFound(String::from("10.96.0.1:80")))
        );
    }

    #[test]
    fn test_weighted_round_robin() {
        let mut lb = make_lb_with(
            LbAlgorithm::WeightedRoundRobin,
            &[("10.0.0.1", 3), ("10.0.0.2", 1)],
        );

        // Over one full cycle (total weight 4) the weight-3 backend must get
        // exactly 3 picks; plain round robin would give it only 2.
        let mut count_first = 0;
        for _ in 0..4 {
            let (addr, _) = lb.select_backend("10.96.0.1", 80, 0).unwrap();
            if addr == "10.0.0.1" {
                count_first += 1;
            }
        }
        assert_eq!(count_first, 3);
    }

    #[test]
    fn test_vip_not_found() {
        let mut lb = L4LoadBalancer::new();
        assert_eq!(
            lb.select_backend("10.96.0.1", 80, 0),
            Err(L4Error::VipNotFound(String::from("10.96.0.1:80")))
        );
    }
}
516}