⚠️ VeridianOS Kernel Documentation - This is low-level `no_std` kernel code. Functions are safe unless explicitly marked `unsafe`; every `unsafe` block carries a SAFETY comment stating its invariant.

veridian_kernel/
thread_api.rs

1//! Thread Management APIs
2//!
3//! High-level thread management interface for user-space applications.
4
5#![allow(
6    clippy::type_complexity,
7    clippy::arc_with_non_send_sync,
8    clippy::manual_div_ceil
9)]
10
11use alloc::{string::String, sync::Arc, vec::Vec};
12use core::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13
14use spin::{Mutex, RwLock};
15
16use crate::{
17    error::KernelError,
18    process::{
19        get_process,
20        thread::{ThreadFs, ThreadId},
21        ProcessId,
22    },
23}; // Use the ThreadId from process::thread module
24
/// Thread priority levels, ordered from lowest (`Idle`) to highest
/// (`RealTime`).
///
/// The discriminant values (0..=4) are stable. Note that the in-kernel
/// `Thread.priority` field appears to use the *inverse* numbering
/// (0 = RealTime, see `ThreadManager::get_thread_priority`) — do not mix
/// the two raw representations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadPriority {
    /// Runs only when nothing else is runnable.
    Idle = 0,
    /// Below-normal priority.
    Low = 1,
    /// Default priority for new threads (see `ThreadAttributes::default`).
    Normal = 2,
    /// Above-normal priority.
    High = 3,
    /// Highest priority level, intended for real-time work.
    RealTime = 4,
}
34
/// Thread lifecycle state, stored in `ThreadHandle::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadState {
    /// Handle exists but the thread has not been made runnable yet.
    Created,
    /// Runnable and waiting to be scheduled.
    Ready,
    /// Currently executing on a CPU.
    Running,
    /// Waiting on some event; `ThreadManager::cancel_thread` moves a
    /// blocked thread back to `Ready` so it can observe cancellation.
    Blocked,
    /// Explicitly suspended.
    Suspended,
    /// Finished; the exit value (if any) is held in
    /// `ThreadHandle::exit_value` until joined.
    Terminated,
}
45
/// Thread scheduling policy requested via `ThreadAttributes::policy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicy {
    Normal,   // Standard round-robin with priorities
    RealTime, // Real-time scheduling
    Batch,    // Batch processing (lower priority)
    Idle,     // Idle threads (run when nothing else)
}
54
/// Thread attributes supplied at creation time (see
/// `ThreadManager::create_thread`); a copy is retained on the resulting
/// `ThreadHandle`.
#[derive(Debug, Clone)]
pub struct ThreadAttributes {
    /// Stack size in bytes (default: 1 MiB; mapped in 4 KiB pages).
    pub stack_size: usize,

    /// Thread priority
    pub priority: ThreadPriority,

    /// Scheduling policy
    pub policy: SchedulingPolicy,

    /// CPU affinity mask (bit mask of allowed CPUs; `u64::MAX` = all CPUs)
    pub cpu_affinity: u64,

    /// Thread name
    pub name: String,

    /// Detached state (true = detached, false = joinable)
    pub detached: bool,

    /// Inherit scheduling from parent
    pub inherit_sched: bool,
}
79
80impl Default for ThreadAttributes {
81    fn default() -> Self {
82        Self {
83            stack_size: 1024 * 1024, // 1 MB default stack
84            priority: ThreadPriority::Normal,
85            policy: SchedulingPolicy::Normal,
86            cpu_affinity: u64::MAX, // All CPUs
87            name: String::from("thread"),
88            detached: false,
89            inherit_sched: true,
90        }
91    }
92}
93
/// Opaque key identifying a thread-local storage slot.
///
/// Keys are handed out by `ThreadManager::create_tls_key` from a global
/// counter; `Ord` is derived so keys can index the `BTreeMap` of
/// destructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct TlsKey(u32);
97
/// Thread entry point function type.
///
/// Receives the opaque `arg` pointer supplied at creation time; by
/// convention the return value is an opaque exit value (presumably
/// forwarded to `exit_thread` — confirm against the user-space runtime).
pub type ThreadEntryPoint = fn(*mut u8) -> *mut u8;
100
/// Thread handle for management.
///
/// Shared (via `Arc`) between the global `ThreadManager` table and any
/// callers holding the result of `create_thread`. All mutable fields are
/// atomics or lock-protected, so `&ThreadHandle` access is safe from
/// multiple contexts.
#[derive(Debug)]
pub struct ThreadHandle {
    pub id: ThreadId,
    pub process_id: ProcessId,
    /// Current lifecycle state; see [`ThreadState`].
    pub state: RwLock<ThreadState>,
    /// Creation-time attributes (copy; not kept in sync with the kernel Thread).
    pub attributes: ThreadAttributes,
    /// Exit value set by `exit_thread`, consumed by `join_thread`.
    pub exit_value: Mutex<Option<*mut u8>>,
    /// True until the thread is detached.
    pub joinable: AtomicBool,
    /// Accumulated CPU time counter (units defined by the scheduler).
    pub cpu_time: AtomicU64,
    /// Number of context switches recorded for this thread.
    pub context_switches: AtomicU64,
    /// Cancellation requested flag
    pub cancel_requested: AtomicBool,
}
115
116impl ThreadHandle {
117    /// Create a new thread handle
118    pub fn new(id: ThreadId, process_id: ProcessId, attributes: ThreadAttributes) -> Self {
119        Self {
120            id,
121            process_id,
122            state: RwLock::new(ThreadState::Created),
123            joinable: AtomicBool::new(!attributes.detached),
124            exit_value: Mutex::new(None),
125            cpu_time: AtomicU64::new(0),
126            context_switches: AtomicU64::new(0),
127            cancel_requested: AtomicBool::new(false),
128            attributes,
129        }
130    }
131
132    /// Get thread state
133    pub fn get_state(&self) -> ThreadState {
134        *self.state.read()
135    }
136
137    /// Set thread state
138    pub fn set_state(&self, state: ThreadState) {
139        *self.state.write() = state;
140    }
141
142    /// Check if thread is joinable
143    pub fn is_joinable(&self) -> bool {
144        self.joinable.load(Ordering::Acquire)
145    }
146
147    /// Detach the thread
148    pub fn detach(&self) {
149        self.joinable.store(false, Ordering::Release);
150    }
151
152    /// Get CPU time used by thread
153    pub fn get_cpu_time(&self) -> u64 {
154        self.cpu_time.load(Ordering::Relaxed)
155    }
156
157    /// Get number of context switches
158    pub fn get_context_switches(&self) -> u64 {
159        self.context_switches.load(Ordering::Relaxed)
160    }
161
162    /// Check if cancellation was requested
163    pub fn is_cancel_requested(&self) -> bool {
164        self.cancel_requested.load(Ordering::Acquire)
165    }
166
167    /// Request cancellation
168    pub fn request_cancel(&self) {
169        self.cancel_requested.store(true, Ordering::Release);
170    }
171}
172
/// Thread creation parameters for `ThreadManager::create_thread`.
pub struct ThreadCreateParams {
    /// Function the new thread starts executing.
    pub entry_point: ThreadEntryPoint,
    /// Opaque argument; written to the top of the new thread's user stack
    /// during creation.
    pub arg: *mut u8,
    /// Stack size, priority, affinity, name, detach state, etc.
    pub attributes: ThreadAttributes,
}
179
/// Thread management system: owns the global thread table and the
/// TLS-key/destructor registry. A single instance lives in the
/// `THREAD_MANAGER` static.
pub struct ThreadManager {
    /// Thread counter for ID generation (starts at 1; 0 means "no thread")
    next_thread_id: AtomicU64,

    /// TLS key counter (starts at 1)
    next_tls_key: AtomicU64,

    /// Global thread table mapping IDs to shared handles
    threads: RwLock<alloc::collections::BTreeMap<ThreadId, Arc<ThreadHandle>>>,

    /// TLS destructors, run when a thread exits (see `run_tls_destructors`)
    tls_destructors: RwLock<alloc::collections::BTreeMap<TlsKey, fn(*mut u8)>>,
}
194
// SAFETY: ThreadManager is not auto-`Send` only because the thread table
// holds `Arc<ThreadHandle>` whose `exit_value` field contains a raw
// `*mut u8`. The manager never dereferences that pointer; it only stores
// and hands it back, and all access goes through the `Mutex` inside
// `ThreadHandle`, so moving the manager between threads is sound.
unsafe impl Send for ThreadManager {}

// SAFETY: Every mutable field of ThreadManager is an atomic or protected
// by `RwLock` (see field types above), so concurrent access through
// shared references is properly synchronized.
unsafe impl Sync for ThreadManager {}
202
203impl Default for ThreadManager {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209impl ThreadManager {
210    /// Create a new thread manager
211    pub fn new() -> Self {
212        Self {
213            next_thread_id: AtomicU64::new(1),
214            next_tls_key: AtomicU64::new(1),
215            threads: RwLock::new(alloc::collections::BTreeMap::new()),
216            tls_destructors: RwLock::new(alloc::collections::BTreeMap::new()),
217        }
218    }
219
220    /// Create a new thread
221    pub fn create_thread(
222        &self,
223        params: ThreadCreateParams,
224        process_id: ProcessId,
225    ) -> Result<Arc<ThreadHandle>, KernelError> {
226        let thread_id = ThreadId(self.next_thread_id.fetch_add(1, Ordering::SeqCst));
227
228        // Create thread handle
229        let handle = Arc::new(ThreadHandle::new(
230            thread_id,
231            process_id,
232            params.attributes.clone(),
233        ));
234
235        // Add to thread table
236        self.threads.write().insert(thread_id, handle.clone());
237
238        // Create actual thread in the process
239        if let Some(process) = get_process(process_id) {
240            // Allocate stack first
241            let stack_size = params.attributes.stack_size;
242            let mut memory_space = process.memory_space.lock();
243
244            // Find a suitable virtual address for the stack
245            let stack_base = 0x70000000; // User stack area
246            let current_thread_count = process.thread_count();
247            let stack_addr = stack_base - (current_thread_count * stack_size);
248
249            // Map stack pages
250            let page_count = (stack_size + 4095) / 4096;
251            for i in 0..page_count {
252                let page_addr = stack_addr + (i * 4096);
253                let page_flags = crate::mm::PageFlags::PRESENT
254                    | crate::mm::PageFlags::USER
255                    | crate::mm::PageFlags::NO_EXECUTE;
256                memory_space.map_page(page_addr, page_flags)?;
257            }
258            drop(memory_space); // Release the lock
259
260            // Create the actual Thread object
261            use crate::process::thread::Thread;
262
263            let kernel_stack_size = 64 * 1024; // 64KB kernel stack
264            let kernel_stack_base = 0x80000000; // Kernel stack area
265            let kernel_stack_addr = kernel_stack_base - (current_thread_count * kernel_stack_size);
266
267            let thread = Thread::new(
268                thread_id,
269                process_id,
270                params.attributes.name.clone(),
271                params.entry_point as usize,
272                stack_addr,
273                stack_size,
274                kernel_stack_addr,
275                kernel_stack_size,
276                ThreadFs::new_root(),
277            );
278
279            // Set stack pointer to top of stack (with argument)
280            let stack_top = stack_addr + stack_size - 8; // Leave space for return address
281            thread.user_stack.set_sp(stack_top);
282
283            // Store thread argument at top of stack
284            // SAFETY: `stack_top` is a virtual address within the
285            // user-space stack region that was just mapped above
286            // (stack_addr .. stack_addr + stack_size).  Writing
287            // one pointer-sized value at the top of the stack is
288            // within bounds.  The stack pages are mapped writable.
289            unsafe {
290                let arg_ptr = stack_top as *mut *mut u8;
291                *arg_ptr = params.arg;
292            }
293
294            // Set CPU affinity
295            thread.set_affinity(params.attributes.cpu_affinity as usize);
296
297            // Add thread to process
298            process.add_thread(thread)?;
299
300            handle.set_state(ThreadState::Ready);
301
302            crate::println!(
303                "[THREAD] Created thread {} in process {} with stack at 0x{:x}",
304                thread_id.0,
305                process_id.0,
306                stack_addr
307            );
308        } else {
309            return Err(KernelError::ProcessNotFound { pid: process_id.0 });
310        }
311
312        Ok(handle)
313    }
314
315    /// Get thread handle by ID
316    pub fn get_thread(&self, thread_id: ThreadId) -> Option<Arc<ThreadHandle>> {
317        self.threads.read().get(&thread_id).cloned()
318    }
319
320    /// Join a thread (wait for it to complete)
321    pub fn join_thread(&self, thread_id: ThreadId) -> Result<*mut u8, KernelError> {
322        let handle = self
323            .get_thread(thread_id)
324            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
325
326        if !handle.is_joinable() {
327            return Err(KernelError::InvalidState {
328                expected: "joinable",
329                actual: "detached",
330            });
331        }
332
333        // Wait for thread to complete
334        loop {
335            if handle.get_state() == ThreadState::Terminated {
336                let exit_value = handle
337                    .exit_value
338                    .lock()
339                    .take()
340                    .unwrap_or(core::ptr::null_mut());
341
342                // Remove from thread table
343                self.threads.write().remove(&thread_id);
344
345                return Ok(exit_value);
346            }
347
348            // Yield to scheduler
349            crate::sched::yield_cpu();
350        }
351    }
352
353    /// Detach a thread
354    pub fn detach_thread(&self, thread_id: ThreadId) -> Result<(), KernelError> {
355        let handle = self
356            .get_thread(thread_id)
357            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
358
359        handle.detach();
360        Ok(())
361    }
362
363    /// Cancel a thread
364    ///
365    /// Sets the cancellation flag on the thread. The thread should check
366    /// is_cancel_requested() periodically and exit gracefully when cancelled.
367    /// If the thread is blocked, it will be moved to Ready state to allow
368    /// it to process the cancellation.
369    pub fn cancel_thread(&self, thread_id: ThreadId) -> Result<(), KernelError> {
370        let handle = self
371            .get_thread(thread_id)
372            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
373
374        // Set cancellation flag - thread should check this and exit gracefully
375        handle.request_cancel();
376
377        // If thread is blocked, wake it up so it can process cancellation
378        if handle.get_state() == ThreadState::Blocked {
379            handle.set_state(ThreadState::Ready);
380        }
381
382        crate::println!("[THREAD] Cancellation requested for thread {}", thread_id.0);
383
384        Ok(())
385    }
386
387    /// Exit current thread
388    pub fn exit_thread(&self, thread_id: ThreadId, exit_value: *mut u8) {
389        if let Some(handle) = self.get_thread(thread_id) {
390            *handle.exit_value.lock() = Some(exit_value);
391            handle.set_state(ThreadState::Terminated);
392
393            // Run TLS destructors
394            self.run_tls_destructors(thread_id);
395
396            crate::println!("[THREAD] Thread {} exited", thread_id.0);
397        }
398    }
399
400    /// Set thread priority
401    pub fn set_thread_priority(
402        &self,
403        thread_id: ThreadId,
404        _priority: ThreadPriority,
405    ) -> Result<(), KernelError> {
406        let handle = self
407            .get_thread(thread_id)
408            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
409
410        // Update process thread priority
411        if let Some(process) = get_process(handle.process_id) {
412            if let Some(_thread) = process.get_thread(thread_id) {
413                // Update thread priority (stored in the thread object)
414                // Note: In a real implementation, we'd need mutable access
415                // For now, just track it in the handle attributes
416                crate::println!(
417                    "[THREAD] Updated thread {} priority (tracked in handle)",
418                    thread_id.0
419                );
420            }
421        }
422
423        crate::println!(
424            "[THREAD] Set thread {} priority to {:?}",
425            thread_id.0,
426            _priority
427        );
428        Ok(())
429    }
430
431    /// Get thread priority
432    pub fn get_thread_priority(&self, thread_id: ThreadId) -> Result<ThreadPriority, KernelError> {
433        let handle = self
434            .get_thread(thread_id)
435            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
436
437        if let Some(process) = get_process(handle.process_id) {
438            if let Some(thread) = process.get_thread(thread_id) {
439                // For now, map from the thread's priority field
440                let priority = match thread.priority {
441                    0 => ThreadPriority::RealTime,
442                    1 => ThreadPriority::High,
443                    2 => ThreadPriority::Normal,
444                    3 => ThreadPriority::Low,
445                    _ => ThreadPriority::Idle,
446                };
447                return Ok(priority);
448            }
449        }
450
451        Err(KernelError::NotFound {
452            resource: "thread context",
453            id: thread_id.0,
454        })
455    }
456
457    /// Set CPU affinity for thread
458    pub fn set_cpu_affinity(&self, thread_id: ThreadId, cpu_mask: u64) -> Result<(), KernelError> {
459        let handle = self
460            .get_thread(thread_id)
461            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
462
463        if let Some(process) = get_process(handle.process_id) {
464            if let Some(thread) = process.get_thread(thread_id) {
465                thread.set_affinity(cpu_mask as usize);
466                crate::println!(
467                    "[THREAD] Set thread {} CPU affinity to 0x{:x}",
468                    thread_id.0,
469                    cpu_mask
470                );
471                return Ok(());
472            }
473        }
474
475        Err(KernelError::NotFound {
476            resource: "thread context",
477            id: thread_id.0,
478        })
479    }
480
481    /// Create thread-local storage key
482    pub fn create_tls_key(&self, destructor: Option<fn(*mut u8)>) -> Result<TlsKey, KernelError> {
483        let key = TlsKey(self.next_tls_key.fetch_add(1, Ordering::SeqCst) as u32);
484
485        if let Some(dtor) = destructor {
486            self.tls_destructors.write().insert(key, dtor);
487        }
488
489        Ok(key)
490    }
491
492    /// Delete thread-local storage key
493    pub fn delete_tls_key(&self, key: TlsKey) -> Result<(), KernelError> {
494        self.tls_destructors.write().remove(&key);
495
496        // Remove from all thread contexts
497        let threads = self.threads.read();
498        for handle in threads.values() {
499            if let Some(process) = get_process(handle.process_id) {
500                if let Some(thread) = process.get_thread(handle.id) {
501                    #[cfg(feature = "alloc")]
502                    {
503                        thread.remove_tls_value(key.0 as u64);
504                    }
505                }
506            }
507        }
508
509        Ok(())
510    }
511
512    /// Set thread-local storage value
513    pub fn set_tls_value(
514        &self,
515        thread_id: ThreadId,
516        key: TlsKey,
517        value: *mut u8,
518    ) -> Result<(), KernelError> {
519        let handle = self
520            .get_thread(thread_id)
521            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
522
523        if let Some(process) = get_process(handle.process_id) {
524            if let Some(thread) = process.get_thread(thread_id) {
525                #[cfg(feature = "alloc")]
526                {
527                    thread.set_tls_value(key.0 as u64, value as u64);
528                    return Ok(());
529                }
530            }
531        }
532
533        Err(KernelError::NotFound {
534            resource: "thread context",
535            id: thread_id.0,
536        })
537    }
538
539    /// Get thread-local storage value
540    pub fn get_tls_value(&self, thread_id: ThreadId, key: TlsKey) -> Result<*mut u8, KernelError> {
541        let handle = self
542            .get_thread(thread_id)
543            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
544
545        if let Some(process) = get_process(handle.process_id) {
546            if let Some(thread) = process.get_thread(thread_id) {
547                #[cfg(feature = "alloc")]
548                {
549                    let value = thread.get_tls_value(key.0 as u64).unwrap_or(0) as *mut u8;
550                    return Ok(value);
551                }
552
553                #[cfg(not(feature = "alloc"))]
554                return Ok(core::ptr::null_mut());
555            }
556        }
557
558        Err(KernelError::NotFound {
559            resource: "thread context",
560            id: thread_id.0,
561        })
562    }
563
564    /// Get current thread ID from scheduler
565    pub fn get_current_thread_id(&self) -> Option<ThreadId> {
566        // Get from scheduler
567        let tid = crate::sched::get_current_thread_id();
568        if tid != 0 {
569            Some(ThreadId(tid))
570        } else {
571            None
572        }
573    }
574
575    /// List all threads
576    pub fn list_threads(&self) -> Vec<ThreadId> {
577        self.threads.read().keys().copied().collect()
578    }
579
580    /// Get thread statistics
581    pub fn get_thread_stats(&self, thread_id: ThreadId) -> Result<ThreadStats, KernelError> {
582        let handle = self
583            .get_thread(thread_id)
584            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;
585
586        Ok(ThreadStats {
587            id: thread_id,
588            process_id: handle.process_id,
589            state: handle.get_state(),
590            priority: handle.attributes.priority,
591            cpu_time: handle.get_cpu_time(),
592            context_switches: handle.get_context_switches(),
593            stack_size: handle.attributes.stack_size,
594            name: handle.attributes.name.clone(),
595        })
596    }
597
598    // Helper functions
599
600    fn run_tls_destructors(&self, thread_id: ThreadId) {
601        let destructors = self.tls_destructors.read();
602
603        if let Some(handle) = self.get_thread(thread_id) {
604            if let Some(process) = get_process(handle.process_id) {
605                if let Some(thread) = process.get_thread(thread_id) {
606                    #[cfg(feature = "alloc")]
607                    {
608                        for (key, dtor) in destructors.iter() {
609                            if let Some(value) = thread.get_tls_value(key.0 as u64) {
610                                if value != 0 {
611                                    dtor(value as *mut u8);
612                                }
613                            }
614                        }
615                    }
616                }
617            }
618        }
619    }
620}
621
/// Point-in-time thread statistics snapshot returned by
/// `ThreadManager::get_thread_stats`. Values are copied from the
/// `ThreadHandle` and do not update afterwards.
#[derive(Debug, Clone)]
pub struct ThreadStats {
    pub id: ThreadId,
    pub process_id: ProcessId,
    pub state: ThreadState,
    /// Priority from the creation-time attributes (not the live scheduler value).
    pub priority: ThreadPriority,
    pub cpu_time: u64,
    pub context_switches: u64,
    pub stack_size: usize,
    pub name: String,
}
634
/// Global thread manager using OnceLock for safe initialization.
/// Set exactly once by [`init`]; later `set` attempts are ignored.
static THREAD_MANAGER: crate::sync::once_lock::OnceLock<ThreadManager> =
    crate::sync::once_lock::OnceLock::new();
638
639/// Initialize the thread manager
640pub fn init() {
641    #[allow(unused_imports)]
642    use crate::println;
643
644    println!("[THREAD_API] Creating new ThreadManager...");
645    match THREAD_MANAGER.set(ThreadManager::new()) {
646        Ok(()) => println!("[THREAD_API] Thread management APIs initialized"),
647        Err(_) => println!("[THREAD_API] Already initialized, skipping..."),
648    }
649}
650
651/// Try to get the global thread manager without panicking.
652///
653/// Returns `None` if the thread manager has not been initialized via [`init`].
654pub fn try_get_thread_manager() -> Option<&'static ThreadManager> {
655    THREAD_MANAGER.get()
656}
657
658/// Get the global thread manager.
659///
660/// Panics if the thread manager has not been initialized via [`init`].
661/// Prefer [`try_get_thread_manager`] in contexts where a panic is unacceptable.
662pub fn get_thread_manager() -> &'static ThreadManager {
663    THREAD_MANAGER
664        .get()
665        .expect("Thread manager not initialized: init() was not called")
666}
667
668// Convenience functions
669
670/// Create a new thread
671pub fn create_thread(
672    entry_point: ThreadEntryPoint,
673    arg: *mut u8,
674    attributes: ThreadAttributes,
675    process_id: ProcessId,
676) -> crate::error::KernelResult<Arc<ThreadHandle>> {
677    let params = ThreadCreateParams {
678        entry_point,
679        arg,
680        attributes,
681    };
682
683    get_thread_manager().create_thread(params, process_id)
684}
685
686/// Join a thread
687pub fn join_thread(thread_id: ThreadId) -> crate::error::KernelResult<*mut u8> {
688    get_thread_manager().join_thread(thread_id)
689}
690
691/// Exit current thread
692pub fn exit_thread(exit_value: *mut u8) -> ! {
693    // Get current thread ID from scheduler
694    let current_thread_id = ThreadId(crate::sched::get_current_thread_id());
695    get_thread_manager().exit_thread(current_thread_id, exit_value);
696
697    // Schedule next thread
698    crate::sched::yield_cpu();
699
700    // Should never reach here
701    loop {
702        core::hint::spin_loop();
703    }
704}
705
706/// Yield CPU to scheduler
707pub fn yield_thread() {
708    crate::sched::yield_cpu();
709}
710
711/// Sleep for a number of milliseconds
712///
713/// Uses timer-based sleep to avoid busy-waiting. The thread will yield
714/// to the scheduler while waiting for the sleep duration to elapse.
715pub fn sleep_ms(ms: u64) {
716    use crate::arch::timer::get_ticks;
717
718    // Get current time in ticks
719    let start_ticks = get_ticks();
720
721    // Convert milliseconds to approximate tick count
722    // Assuming ~1000 ticks per second (typical timer frequency)
723    // This may vary by architecture - adjust tick_rate as needed
724    const TICKS_PER_MS: u64 = 1;
725    let target_ticks = ms.saturating_mul(TICKS_PER_MS);
726    let end_ticks = start_ticks.saturating_add(target_ticks);
727
728    // Sleep loop - yield to scheduler while waiting
729    while get_ticks() < end_ticks {
730        // Yield to allow other threads to run
731        yield_thread();
732
733        // Brief spin to avoid hammering the scheduler
734        for _ in 0..100 {
735            core::hint::spin_loop();
736        }
737    }
738}