1use core::sync::atomic::{AtomicU32, AtomicU64, Ordering};
8
9use super::{
10 capability::ProcessId,
11 error::{IpcError, Result},
12};
13use crate::arch::entropy::read_timestamp;
14
/// Per-process IPC rate limiter backed by a fixed-size table of token buckets.
pub struct RateLimiter {
    /// Bucket table; a process's slot is derived from `pid % MAX_PROCESSES`.
    buckets: [TokenBucket; MAX_PROCESSES],
}
20
/// Number of bucket slots in a `RateLimiter`. Process ids are reduced modulo this
/// value to choose a slot, so distinct pids can map to the same slot.
const MAX_PROCESSES: usize = 1024;
23
24struct TokenBucket {
26 pid: AtomicU64,
28 tokens: AtomicU32,
30 max_tokens: AtomicU32,
32 refill_rate: AtomicU32,
34 last_refill: AtomicU64,
36 messages_sent: AtomicU64,
38 bytes_sent: AtomicU64,
40}
41
42impl TokenBucket {
43 const fn new() -> Self {
44 Self {
45 pid: AtomicU64::new(0),
46 tokens: AtomicU32::new(100),
47 max_tokens: AtomicU32::new(100),
48 refill_rate: AtomicU32::new(100),
49 last_refill: AtomicU64::new(0),
50 messages_sent: AtomicU64::new(0),
51 bytes_sent: AtomicU64::new(0),
52 }
53 }
54
55 fn try_consume(&self, tokens_needed: u32) -> bool {
57 self.refill();
59
60 let mut current = self.tokens.load(Ordering::Acquire);
62 loop {
63 if current < tokens_needed {
64 return false;
65 }
66
67 match self.tokens.compare_exchange_weak(
68 current,
69 current - tokens_needed,
70 Ordering::Release,
71 Ordering::Acquire,
72 ) {
73 Ok(_) => return true,
74 Err(val) => current = val,
75 }
76 }
77 }
78
79 fn refill(&self) {
81 let now = get_current_time();
82 let last = self.last_refill.load(Ordering::Acquire);
83 let elapsed_ms = (now - last) / 1_000_000; if elapsed_ms > 0 {
86 let refill_rate = self.refill_rate.load(Ordering::Relaxed);
87 let tokens_to_add = (refill_rate as u64 * elapsed_ms / 1000) as u32;
88
89 if tokens_to_add > 0 {
90 self.last_refill.store(now, Ordering::Release);
92
93 let max_tokens = self.max_tokens.load(Ordering::Relaxed);
95 let mut current = self.tokens.load(Ordering::Acquire);
96
97 loop {
98 let new_tokens = (current + tokens_to_add).min(max_tokens);
99 match self.tokens.compare_exchange_weak(
100 current,
101 new_tokens,
102 Ordering::Release,
103 Ordering::Acquire,
104 ) {
105 Ok(_) => break,
106 Err(val) => current = val,
107 }
108 }
109 }
110 }
111 }
112
113 fn reset(&self, pid: ProcessId, max_tokens: u32, refill_rate: u32) {
115 self.pid.store(pid.0, Ordering::Release);
116 self.tokens.store(max_tokens, Ordering::Release);
117 self.max_tokens.store(max_tokens, Ordering::Release);
118 self.refill_rate.store(refill_rate, Ordering::Release);
119 self.last_refill
120 .store(get_current_time(), Ordering::Release);
121 self.messages_sent.store(0, Ordering::Release);
122 self.bytes_sent.store(0, Ordering::Release);
123 }
124}
125
126impl RateLimiter {
127 pub const fn new() -> Self {
129 #[allow(clippy::declare_interior_mutable_const)]
131 const BUCKET: TokenBucket = TokenBucket::new();
132 Self {
133 buckets: [BUCKET; MAX_PROCESSES],
134 }
135 }
136
137 pub fn check_allowed(
139 &self,
140 pid: ProcessId,
141 message_size: usize,
142 limits: &RateLimits,
143 ) -> Result<()> {
144 let bucket = self.get_or_create_bucket(pid, limits)?;
146
147 if limits.max_messages_per_sec > 0 {
149 let tokens_needed = 1;
150 if !bucket.try_consume(tokens_needed) {
151 return Err(IpcError::RateLimitExceeded);
152 }
153 }
154
155 bucket.messages_sent.fetch_add(1, Ordering::Relaxed);
157 bucket
158 .bytes_sent
159 .fetch_add(message_size as u64, Ordering::Relaxed);
160
161 if limits.max_bytes_per_sec > 0 {
163 let bytes_sent = bucket.bytes_sent.load(Ordering::Relaxed);
164 if bytes_sent > limits.max_bytes_per_sec {
165 return Err(IpcError::RateLimitExceeded);
166 }
167 }
168
169 Ok(())
170 }
171
172 fn get_or_create_bucket(&self, pid: ProcessId, limits: &RateLimits) -> Result<&TokenBucket> {
174 let index = (pid.0 as usize) % MAX_PROCESSES;
176 let bucket = &self.buckets[index];
177
178 let current_pid = bucket.pid.load(Ordering::Acquire);
180 if current_pid == pid.0 {
181 return Ok(bucket);
182 }
183
184 if current_pid == 0 {
186 match bucket
187 .pid
188 .compare_exchange(0, pid.0, Ordering::Release, Ordering::Acquire)
189 {
190 Ok(_) => {
191 bucket.reset(
193 pid,
194 limits.max_messages_per_sec,
195 limits.max_messages_per_sec,
196 );
197 return Ok(bucket);
198 }
199 Err(_) => {
200 if bucket.pid.load(Ordering::Acquire) == pid.0 {
202 return Ok(bucket);
203 }
204 }
205 }
206 }
207
208 Ok(bucket)
211 }
212
213 pub fn get_stats(&self, pid: ProcessId) -> RateLimitStats {
215 let index = (pid.0 as usize) % MAX_PROCESSES;
216 let bucket = &self.buckets[index];
217
218 if bucket.pid.load(Ordering::Acquire) == pid.0 {
219 RateLimitStats {
220 messages_sent: bucket.messages_sent.load(Ordering::Relaxed),
221 bytes_sent: bucket.bytes_sent.load(Ordering::Relaxed),
222 tokens_available: bucket.tokens.load(Ordering::Relaxed),
223 max_tokens: bucket.max_tokens.load(Ordering::Relaxed),
224 }
225 } else {
226 RateLimitStats::default()
227 }
228 }
229}
230
231impl Default for RateLimiter {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
/// Rate-limit policy for one process.
///
/// A `0` in either per-second limit means "no limit" for that dimension.
#[derive(Debug, Clone, Copy)]
pub struct RateLimits {
    /// Messages allowed per second (`0` = unlimited).
    pub max_messages_per_sec: u32,
    /// Bytes allowed per second (`0` = unlimited).
    pub max_bytes_per_sec: u64,
    /// Burst factor relative to `max_messages_per_sec`.
    /// NOTE(review): confirm how the limiter applies this to bucket capacity.
    pub burst_multiplier: u32,
}

impl RateLimits {
    /// Bytes in one mebibyte.
    const MIB: u64 = 1 << 20;

    /// Builds a policy from its three fields; shared by the presets below.
    const fn preset(max_messages_per_sec: u32, max_bytes_per_sec: u64, burst_multiplier: u32) -> Self {
        Self {
            max_messages_per_sec,
            max_bytes_per_sec,
            burst_multiplier,
        }
    }

    /// Policy with both limits disabled and no burst scaling.
    pub const fn unlimited() -> Self {
        Self::preset(0, 0, 1)
    }

    /// Standard policy: 1000 msg/s, 10 MiB/s, burst multiplier 2.
    ///
    /// This is an inherent `const fn` (not `Default::default`) so it can
    /// initialize statics and consts.
    pub const fn default() -> Self {
        Self::preset(1000, 10 * Self::MIB, 2)
    }

    /// Tight policy: 100 msg/s, 1 MiB/s, no burst scaling.
    pub const fn strict() -> Self {
        Self::preset(100, Self::MIB, 1)
    }
}
276
/// Snapshot of one process's rate-limiter counters.
///
/// `Default` gives all-zero values, which is what is reported for a process
/// that owns no bucket.
#[derive(Debug, Default)]
pub struct RateLimitStats {
    /// Messages counted for this process.
    pub messages_sent: u64,
    /// Bytes counted for this process.
    pub bytes_sent: u64,
    /// Tokens in the bucket at snapshot time; no refill is applied first.
    pub tokens_available: u32,
    /// Capacity of the process's token bucket.
    pub max_tokens: u32,
}
285
/// Crate-wide rate limiter instance, built at compile time via the `const fn` constructor.
// NOTE(review): `dead_code` is allowed without explanation; presumably the static
// is unused in some build configurations. Confirm and document the reason.
#[allow(dead_code)]
pub(crate) static RATE_LIMITER: RateLimiter = RateLimiter::new();
289
/// Current timestamp, as returned by `read_timestamp()`.
///
/// Rate computations in this file treat the value as nanoseconds.
/// NOTE(review): confirm `read_timestamp()` actually reports nanoseconds and is
/// monotonic; if it returns raw cycle counts, every configured rate is off by the
/// clock frequency.
fn get_current_time() -> u64 {
    read_timestamp()
}
296
#[cfg(all(test, not(target_os = "none")))]
mod tests {
    use super::*;
    use crate::process::ProcessId;

    #[test]
    fn test_token_bucket() {
        let bucket = TokenBucket::new();
        bucket.reset(ProcessId(1), 10, 10);

        // A full 10-token bucket covers exactly two 5-token requests...
        assert!(bucket.try_consume(5));
        assert!(bucket.try_consume(5));

        // ...and is then empty; at 10 tokens/s a new token takes ~100 ms.
        assert!(!bucket.try_consume(1));
    }

    #[test]
    fn test_rate_limiter() {
        // Test-local limiter: the global `RATE_LIMITER` is shared by the whole
        // test binary, so asserting on its counters depends on other tests.
        static LIMITER: RateLimiter = RateLimiter::new();
        let limits = RateLimits {
            max_messages_per_sec: 10,
            max_bytes_per_sec: 1000,
            burst_multiplier: 1,
        };

        assert!(LIMITER.check_allowed(ProcessId(1), 100, &limits).is_ok());

        let stats = LIMITER.get_stats(ProcessId(1));
        assert_eq!(stats.messages_sent, 1);
        assert_eq!(stats.bytes_sent, 100);
    }

    #[test]
    fn test_message_budget_exhausted() {
        static LIMITER: RateLimiter = RateLimiter::new();
        // Byte limit disabled so only the message budget is exercised.
        let limits = RateLimits {
            max_messages_per_sec: 3,
            max_bytes_per_sec: 0,
            burst_multiplier: 1,
        };

        for _ in 0..3 {
            assert!(LIMITER.check_allowed(ProcessId(7), 1, &limits).is_ok());
        }
        // At 3 tokens/s the next token needs ~333 ms, far longer than this test.
        assert!(LIMITER.check_allowed(ProcessId(7), 1, &limits).is_err());
    }

    #[test]
    fn test_unknown_process_has_default_stats() {
        static LIMITER: RateLimiter = RateLimiter::new();
        let stats = LIMITER.get_stats(ProcessId(42));
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.bytes_sent, 0);
        assert_eq!(stats.tokens_available, 0);
        assert_eq!(stats.max_tokens, 0);
    }
}
333}