⚠️ VeridianOS Kernel Documentation — low-level `no_std` kernel code. All functions are considered unsafe unless explicitly marked otherwise.

veridian_kernel/ipc/
rate_limit.rs

1//! IPC rate limiting implementation
2//!
3//! Provides rate limiting for IPC operations to prevent DoS attacks
4//! and ensure fair resource usage.
5
6// IPC rate limiting
7use core::sync::atomic::{AtomicU32, AtomicU64, Ordering};
8
9use super::{
10    capability::ProcessId,
11    error::{IpcError, Result},
12};
13use crate::arch::entropy::read_timestamp;
14
/// Rate limiter for IPC operations.
///
/// A fixed-size table of per-process token buckets, indexed by hashing
/// the PID (`pid % MAX_PROCESSES`). All bucket state is atomic, so the
/// limiter can be consulted concurrently without locks. Distinct PIDs
/// may hash to the same bucket; see `get_or_create_bucket` for how
/// collisions are handled.
pub struct RateLimiter {
    /// Token bucket for rate limiting
    buckets: [TokenBucket; MAX_PROCESSES],
}

/// Maximum number of processes to track.
/// PIDs beyond this count share buckets via modulo hashing.
const MAX_PROCESSES: usize = 1024;
23
/// Token bucket for rate limiting.
///
/// Every field is an atomic so a bucket can be read and updated from
/// multiple CPUs without a lock; the methods in `impl TokenBucket` use
/// CAS loops for the read-modify-write paths.
struct TokenBucket {
    /// Process ID this bucket belongs to (0 = unclaimed)
    pid: AtomicU64,
    /// Current number of tokens
    tokens: AtomicU32,
    /// Maximum tokens (bucket capacity)
    max_tokens: AtomicU32,
    /// Tokens per second refill rate
    refill_rate: AtomicU32,
    /// Last refill timestamp (units of `get_current_time()`)
    last_refill: AtomicU64,
    /// Messages sent in current window
    messages_sent: AtomicU64,
    /// Bytes sent in current window
    bytes_sent: AtomicU64,
}
41
42impl TokenBucket {
43    const fn new() -> Self {
44        Self {
45            pid: AtomicU64::new(0),
46            tokens: AtomicU32::new(100),
47            max_tokens: AtomicU32::new(100),
48            refill_rate: AtomicU32::new(100),
49            last_refill: AtomicU64::new(0),
50            messages_sent: AtomicU64::new(0),
51            bytes_sent: AtomicU64::new(0),
52        }
53    }
54
55    /// Try to consume tokens
56    fn try_consume(&self, tokens_needed: u32) -> bool {
57        // First refill tokens based on elapsed time
58        self.refill();
59
60        // Try to consume tokens
61        let mut current = self.tokens.load(Ordering::Acquire);
62        loop {
63            if current < tokens_needed {
64                return false;
65            }
66
67            match self.tokens.compare_exchange_weak(
68                current,
69                current - tokens_needed,
70                Ordering::Release,
71                Ordering::Acquire,
72            ) {
73                Ok(_) => return true,
74                Err(val) => current = val,
75            }
76        }
77    }
78
79    /// Refill tokens based on elapsed time
80    fn refill(&self) {
81        let now = get_current_time();
82        let last = self.last_refill.load(Ordering::Acquire);
83        let elapsed_ms = (now - last) / 1_000_000; // Convert ns to ms
84
85        if elapsed_ms > 0 {
86            let refill_rate = self.refill_rate.load(Ordering::Relaxed);
87            let tokens_to_add = (refill_rate as u64 * elapsed_ms / 1000) as u32;
88
89            if tokens_to_add > 0 {
90                // Update last refill time
91                self.last_refill.store(now, Ordering::Release);
92
93                // Add tokens, capping at max
94                let max_tokens = self.max_tokens.load(Ordering::Relaxed);
95                let mut current = self.tokens.load(Ordering::Acquire);
96
97                loop {
98                    let new_tokens = (current + tokens_to_add).min(max_tokens);
99                    match self.tokens.compare_exchange_weak(
100                        current,
101                        new_tokens,
102                        Ordering::Release,
103                        Ordering::Acquire,
104                    ) {
105                        Ok(_) => break,
106                        Err(val) => current = val,
107                    }
108                }
109            }
110        }
111    }
112
113    /// Reset the bucket for a new process
114    fn reset(&self, pid: ProcessId, max_tokens: u32, refill_rate: u32) {
115        self.pid.store(pid.0, Ordering::Release);
116        self.tokens.store(max_tokens, Ordering::Release);
117        self.max_tokens.store(max_tokens, Ordering::Release);
118        self.refill_rate.store(refill_rate, Ordering::Release);
119        self.last_refill
120            .store(get_current_time(), Ordering::Release);
121        self.messages_sent.store(0, Ordering::Release);
122        self.bytes_sent.store(0, Ordering::Release);
123    }
124}
125
126impl RateLimiter {
127    /// Create a new rate limiter
128    pub const fn new() -> Self {
129        // Can't use array initialization with const fn, so we'll do it manually
130        #[allow(clippy::declare_interior_mutable_const)]
131        const BUCKET: TokenBucket = TokenBucket::new();
132        Self {
133            buckets: [BUCKET; MAX_PROCESSES],
134        }
135    }
136
137    /// Check if an operation is allowed
138    pub fn check_allowed(
139        &self,
140        pid: ProcessId,
141        message_size: usize,
142        limits: &RateLimits,
143    ) -> Result<()> {
144        // Find or allocate bucket for this process
145        let bucket = self.get_or_create_bucket(pid, limits)?;
146
147        // Check message rate limit
148        if limits.max_messages_per_sec > 0 {
149            let tokens_needed = 1;
150            if !bucket.try_consume(tokens_needed) {
151                return Err(IpcError::RateLimitExceeded);
152            }
153        }
154
155        // Update statistics
156        bucket.messages_sent.fetch_add(1, Ordering::Relaxed);
157        bucket
158            .bytes_sent
159            .fetch_add(message_size as u64, Ordering::Relaxed);
160
161        // Check bandwidth limit
162        if limits.max_bytes_per_sec > 0 {
163            let bytes_sent = bucket.bytes_sent.load(Ordering::Relaxed);
164            if bytes_sent > limits.max_bytes_per_sec {
165                return Err(IpcError::RateLimitExceeded);
166            }
167        }
168
169        Ok(())
170    }
171
172    /// Get or create a bucket for a process
173    fn get_or_create_bucket(&self, pid: ProcessId, limits: &RateLimits) -> Result<&TokenBucket> {
174        // Hash the PID to a bucket index
175        let index = (pid.0 as usize) % MAX_PROCESSES;
176        let bucket = &self.buckets[index];
177
178        // Check if this bucket is for our process
179        let current_pid = bucket.pid.load(Ordering::Acquire);
180        if current_pid == pid.0 {
181            return Ok(bucket);
182        }
183
184        // Try to claim this bucket
185        if current_pid == 0 {
186            match bucket
187                .pid
188                .compare_exchange(0, pid.0, Ordering::Release, Ordering::Acquire)
189            {
190                Ok(_) => {
191                    // Successfully claimed, initialize it
192                    bucket.reset(
193                        pid,
194                        limits.max_messages_per_sec,
195                        limits.max_messages_per_sec,
196                    );
197                    return Ok(bucket);
198                }
199                Err(_) => {
200                    // Someone else claimed it, check if it's ours now
201                    if bucket.pid.load(Ordering::Acquire) == pid.0 {
202                        return Ok(bucket);
203                    }
204                }
205            }
206        }
207
208        // Bucket collision - for now, allow the operation
209        // In production, we'd implement a more sophisticated scheme
210        Ok(bucket)
211    }
212
213    /// Get statistics for a process
214    pub fn get_stats(&self, pid: ProcessId) -> RateLimitStats {
215        let index = (pid.0 as usize) % MAX_PROCESSES;
216        let bucket = &self.buckets[index];
217
218        if bucket.pid.load(Ordering::Acquire) == pid.0 {
219            RateLimitStats {
220                messages_sent: bucket.messages_sent.load(Ordering::Relaxed),
221                bytes_sent: bucket.bytes_sent.load(Ordering::Relaxed),
222                tokens_available: bucket.tokens.load(Ordering::Relaxed),
223                max_tokens: bucket.max_tokens.load(Ordering::Relaxed),
224            }
225        } else {
226            RateLimitStats::default()
227        }
228    }
229}
230
231impl Default for RateLimiter {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
/// Rate limit configuration.
///
/// A value of 0 in either rate field disables that check entirely
/// (see `RateLimiter::check_allowed`).
#[derive(Debug, Clone, Copy)]
pub struct RateLimits {
    /// Maximum messages per second (0 = unlimited)
    pub max_messages_per_sec: u32,
    /// Maximum bytes per second (0 = unlimited)
    pub max_bytes_per_sec: u64,
    /// Burst capacity multiplier
    pub burst_multiplier: u32,
}
247
248impl RateLimits {
249    /// Create unlimited rate limits
250    pub const fn unlimited() -> Self {
251        Self {
252            max_messages_per_sec: 0,
253            max_bytes_per_sec: 0,
254            burst_multiplier: 1,
255        }
256    }
257
258    /// Create default rate limits
259    pub const fn default() -> Self {
260        Self {
261            max_messages_per_sec: 1000,
262            max_bytes_per_sec: 10 * 1024 * 1024, // 10 MB/s
263            burst_multiplier: 2,
264        }
265    }
266
267    /// Create strict rate limits
268    pub const fn strict() -> Self {
269        Self {
270            max_messages_per_sec: 100,
271            max_bytes_per_sec: 1024 * 1024, // 1 MB/s
272            burst_multiplier: 1,
273        }
274    }
275}
276
/// Rate limit statistics.
///
/// Snapshot returned by `RateLimiter::get_stats`; all zeroes (via
/// `Default`) when the queried process owns no bucket.
#[derive(Debug, Default)]
pub struct RateLimitStats {
    pub messages_sent: u64,
    pub bytes_sent: u64,
    pub tokens_available: u32,
    pub max_tokens: u32,
}
285
/// Global rate limiter instance.
///
/// Statically allocated (1024 buckets of atomics); safe to share
/// because all `RateLimiter` state is atomic.
#[allow(dead_code)]
pub(crate) static RATE_LIMITER: RateLimiter = RateLimiter::new();
289
/// Get the current timestamp from the architecture's hardware counter.
///
/// NOTE(review): despite earlier docs claiming nanoseconds, this
/// returns raw counter ticks (RDTSC on x86_64, CNTVCT_EL0 on AArch64,
/// rdcycle on RISC-V), whose frequency is hardware-dependent — the
/// ns-based arithmetic in `TokenBucket::refill` assumes ns. TODO
/// confirm units and convert if necessary.
fn get_current_time() -> u64 {
    // Uses the centralized hardware timestamp counter from arch::entropy.
    // On x86_64 this reads RDTSC, on AArch64 CNTVCT_EL0, on RISC-V rdcycle.
    read_timestamp()
}
296
// Host-only tests: compiled only for `cargo test` on a hosted target,
// never for the bare-metal (target_os = "none") kernel build.
#[cfg(all(test, not(target_os = "none")))]
mod tests {
    use super::*;
    // NOTE(review): imports ProcessId from crate::process while the
    // module itself uses super::capability::ProcessId — presumably one
    // re-exports the other; verify they are the same type.
    use crate::process::ProcessId;

    #[test]
    fn test_token_bucket() {
        // A bucket with capacity 10 should yield exactly 10 tokens
        // before refusing (refill is negligible within the test).
        let bucket = TokenBucket::new();
        bucket.reset(ProcessId(1), 10, 10);

        // Should be able to consume tokens
        assert!(bucket.try_consume(5));
        assert!(bucket.try_consume(5));

        // Should fail - no tokens left
        assert!(!bucket.try_consume(1));
    }

    #[test]
    fn test_rate_limiter() {
        // NOTE(review): exercises the global RATE_LIMITER, so this test
        // shares state with any other test using ProcessId(1) — the
        // exact-count asserts below assume it runs in isolation.
        let limits = RateLimits {
            max_messages_per_sec: 10,
            max_bytes_per_sec: 1000,
            burst_multiplier: 1,
        };

        // Should allow initial messages
        assert!(RATE_LIMITER
            .check_allowed(ProcessId(1), 100, &limits)
            .is_ok());

        // Get stats
        let stats = RATE_LIMITER.get_stats(ProcessId(1));
        assert_eq!(stats.messages_sent, 1);
        assert_eq!(stats.bytes_sent, 100);
    }
}
333}