1#![allow(
6 clippy::type_complexity,
7 clippy::arc_with_non_send_sync,
8 clippy::manual_div_ceil
9)]
10
11use alloc::{string::String, sync::Arc, vec::Vec};
12use core::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13
14use spin::{Mutex, RwLock};
15
16use crate::{
17 error::KernelError,
18 process::{
19 get_process,
20 thread::{ThreadFs, ThreadId},
21 ProcessId,
22 },
23}; #[derive(Debug, Clone, Copy, PartialEq, Eq)]
27pub enum ThreadPriority {
28 Idle = 0,
29 Low = 1,
30 Normal = 2,
31 High = 3,
32 RealTime = 4,
33}
34
/// Lifecycle states tracked in `ThreadHandle::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadState {
    // Allocated but not yet schedulable (initial state in ThreadHandle::new).
    Created,
    // Runnable, waiting for a CPU.
    Ready,
    Running,
    // Waiting on an event; `cancel_thread` moves Blocked -> Ready.
    Blocked,
    Suspended,
    // Final state; the exit value (if any) lives in ThreadHandle::exit_value.
    Terminated,
}
45
/// Scheduling class requested for a thread via `ThreadAttributes::policy`.
///
/// Only stored here; how each policy is honored is up to the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicy {
    /// Default time-sharing policy.
    Normal,
    /// Real-time scheduling class.
    RealTime,
    /// Batch (throughput-oriented) class.
    Batch,
    /// Lowest class; presumably runs only when nothing else is runnable —
    /// confirm against the scheduler implementation.
    Idle,
}
54
/// Creation-time configuration for a thread (analogous to `pthread_attr_t`).
///
/// Defaults are provided by the `Default` impl below.
#[derive(Debug, Clone)]
pub struct ThreadAttributes {
    // User stack size in bytes.
    pub stack_size: usize,

    pub priority: ThreadPriority,

    pub policy: SchedulingPolicy,

    // Bitmask of CPUs the thread may run on; u64::MAX = any CPU.
    pub cpu_affinity: u64,

    // Human-readable name, used for the scheduler-level Thread.
    pub name: String,

    // If true the thread starts detached (ThreadHandle::joinable = false).
    pub detached: bool,

    // Whether to inherit scheduling parameters; stored but not read in this
    // module — presumably consumed by the scheduler. Confirm.
    pub inherit_sched: bool,
}
79
80impl Default for ThreadAttributes {
81 fn default() -> Self {
82 Self {
83 stack_size: 1024 * 1024, priority: ThreadPriority::Normal,
85 policy: SchedulingPolicy::Normal,
86 cpu_affinity: u64::MAX, name: String::from("thread"),
88 detached: false,
89 inherit_sched: true,
90 }
91 }
92}
93
/// Opaque identifier for a thread-local-storage slot (cf. `pthread_key_t`).
/// `Ord` is derived so keys can index the manager's `BTreeMap`s.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct TlsKey(u32);
97
/// Entry function of a new thread: receives the `arg` pointer passed to
/// `create_thread` and returns the thread's exit value.
pub type ThreadEntryPoint = fn(*mut u8) -> *mut u8;
100
/// Shared bookkeeping record for one thread, held in the manager's registry
/// and by joiners via `Arc`.
#[derive(Debug)]
pub struct ThreadHandle {
    pub id: ThreadId,
    pub process_id: ProcessId,
    // Lifecycle state; see ThreadState.
    pub state: RwLock<ThreadState>,
    // Creation attributes, kept for stats reporting.
    pub attributes: ThreadAttributes,
    // Exit value set by exit_thread; `take`n by the joiner.
    pub exit_value: Mutex<Option<*mut u8>>,
    // False once detached; checked by join_thread.
    pub joinable: AtomicBool,
    // Statistics counters (Relaxed access only).
    pub cpu_time: AtomicU64,
    pub context_switches: AtomicU64,
    // Cooperative cancellation flag; the thread must poll it.
    pub cancel_requested: AtomicBool,
}
115
impl ThreadHandle {
    /// Builds a handle in the `Created` state.
    /// `joinable` starts as the inverse of `attributes.detached`.
    pub fn new(id: ThreadId, process_id: ProcessId, attributes: ThreadAttributes) -> Self {
        Self {
            id,
            process_id,
            state: RwLock::new(ThreadState::Created),
            joinable: AtomicBool::new(!attributes.detached),
            exit_value: Mutex::new(None),
            cpu_time: AtomicU64::new(0),
            context_switches: AtomicU64::new(0),
            cancel_requested: AtomicBool::new(false),
            attributes,
        }
    }

    /// Current lifecycle state (snapshot; may change immediately after).
    pub fn get_state(&self) -> ThreadState {
        *self.state.read()
    }

    /// Overwrites the lifecycle state. No transition validation is done.
    pub fn set_state(&self, state: ThreadState) {
        *self.state.write() = state;
    }

    /// Whether the thread can still be joined.
    // Acquire pairs with the Release store in `detach`.
    pub fn is_joinable(&self) -> bool {
        self.joinable.load(Ordering::Acquire)
    }

    /// Marks the thread detached. There is no way to re-attach.
    pub fn detach(&self) {
        self.joinable.store(false, Ordering::Release);
    }

    /// Accumulated CPU time counter (units set by whoever updates the
    /// counter — not visible in this module).
    // Relaxed: purely a statistics counter, no ordering required.
    pub fn get_cpu_time(&self) -> u64 {
        self.cpu_time.load(Ordering::Relaxed)
    }

    /// Number of context switches recorded for this thread.
    pub fn get_context_switches(&self) -> u64 {
        self.context_switches.load(Ordering::Relaxed)
    }

    /// Whether cooperative cancellation has been requested.
    // Acquire pairs with the Release store in `request_cancel`.
    pub fn is_cancel_requested(&self) -> bool {
        self.cancel_requested.load(Ordering::Acquire)
    }

    /// Requests cooperative cancellation; the thread must poll for it.
    pub fn request_cancel(&self) {
        self.cancel_requested.store(true, Ordering::Release);
    }
}
172
/// Bundle of arguments for `ThreadManager::create_thread`.
pub struct ThreadCreateParams {
    pub entry_point: ThreadEntryPoint,
    // Opaque argument for `entry_point`; written onto the new user stack.
    pub arg: *mut u8,
    pub attributes: ThreadAttributes,
}
179
/// Global registry and factory for threads and TLS keys.
pub struct ThreadManager {
    // Monotonic thread-ID source; starts at 1 (0 means "no thread").
    next_thread_id: AtomicU64,

    // Monotonic TLS-key source; starts at 1.
    next_tls_key: AtomicU64,

    // All live (created, not yet joined) threads by ID.
    threads: RwLock<alloc::collections::BTreeMap<ThreadId, Arc<ThreadHandle>>>,

    // Destructors registered via `create_tls_key`, run on thread exit.
    tls_destructors: RwLock<alloc::collections::BTreeMap<TlsKey, fn(*mut u8)>>,
}
194
// SAFETY: ThreadManager's own fields are atomics and lock-protected maps.
// The auto traits are lost only because ThreadHandle carries raw `*mut u8`
// exit values; every access to those pointers goes through a Mutex.
// NOTE(review): this asserts those raw pointers are safe to move between
// CPUs/threads — confirm no caller relies on thread-affine pointers.
unsafe impl Send for ThreadManager {}

// SAFETY: see the Send impl above; shared access is mediated by the
// interior RwLock/Mutex fields.
unsafe impl Sync for ThreadManager {}
202
203impl Default for ThreadManager {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
impl ThreadManager {
    /// Creates an empty manager. Thread IDs and TLS keys both start at 1,
    /// leaving 0 free as a "none" sentinel (see `get_current_thread_id`).
    pub fn new() -> Self {
        Self {
            next_thread_id: AtomicU64::new(1),
            next_tls_id: AtomicU64::new(1),
            threads: RwLock::new(alloc::collections::BTreeMap::new()),
            tls_destructors: RwLock::new(alloc::collections::BTreeMap::new()),
        }
    }

    /// Creates a new thread inside `process_id`.
    ///
    /// Maps a user stack, computes a kernel-stack address, builds the
    /// scheduler-level `Thread`, stores the entry argument on the user
    /// stack, registers the thread with the process, and marks it `Ready`.
    ///
    /// # Errors
    /// `ProcessNotFound` if the process does not exist, plus any error from
    /// page mapping or `Process::add_thread`.
    pub fn create_thread(
        &self,
        params: ThreadCreateParams,
        process_id: ProcessId,
    ) -> Result<Arc<ThreadHandle>, KernelError> {
        let thread_id = ThreadId(self.next_thread_id.fetch_add(1, Ordering::SeqCst));

        let handle = Arc::new(ThreadHandle::new(
            thread_id,
            process_id,
            params.attributes.clone(),
        ));

        // Registered up front so the ID is immediately visible.
        // NOTE(review): if a later step fails (`?` below), the handle stays
        // in the map — confirm whether that leak is acceptable.
        self.threads.write().insert(thread_id, handle.clone());

        if let Some(process) = get_process(process_id) {
            let stack_size = params.attributes.stack_size;
            let mut memory_space = process.memory_space.lock();

            // User stacks are carved downward from a fixed base, one slot
            // per existing thread.
            // NOTE(review): assumes every thread uses the same stack_size;
            // mixed sizes would make slots overlap. Confirm.
            let stack_base = 0x70000000;
            let current_thread_count = process.thread_count();
            let stack_addr = stack_base - (current_thread_count * stack_size);

            // Map the stack pages: user-accessible, present, no-execute.
            let page_count = (stack_size + 4095) / 4096; // round up to 4 KiB pages
            for i in 0..page_count {
                let page_addr = stack_addr + (i * 4096);
                let page_flags = crate::mm::PageFlags::PRESENT
                    | crate::mm::PageFlags::USER
                    | crate::mm::PageFlags::NO_EXECUTE;
                memory_space.map_page(page_addr, page_flags)?;
            }
            // Release the memory-space lock before touching the process again.
            drop(memory_space);
            use crate::process::thread::Thread;

            // Kernel stacks use the same downward slot scheme from a higher base.
            let kernel_stack_size = 64 * 1024; // 64 KiB per kernel stack
            let kernel_stack_base = 0x80000000;
            let kernel_stack_addr = kernel_stack_base - (current_thread_count * kernel_stack_size);

            let thread = Thread::new(
                thread_id,
                process_id,
                params.attributes.name.clone(),
                params.entry_point as usize,
                stack_addr,
                stack_size,
                kernel_stack_addr,
                kernel_stack_size,
                ThreadFs::new_root(),
            );

            // Initial SP sits 8 bytes below the top, leaving one pointer-sized
            // slot that holds the entry argument.
            let stack_top = stack_addr + stack_size - 8;
            thread.user_stack.set_sp(stack_top);

            // SAFETY: `stack_top` lies inside the range mapped above.
            // NOTE(review): this writes through the user virtual address
            // directly, which is only sound if the target process's address
            // space is currently active — confirm.
            unsafe {
                let arg_ptr = stack_top as *mut *mut u8;
                *arg_ptr = params.arg;
            }

            thread.set_affinity(params.attributes.cpu_affinity as usize);

            process.add_thread(thread)?;

            handle.set_state(ThreadState::Ready);

            crate::println!(
                "[THREAD] Created thread {} in process {} with stack at 0x{:x}",
                thread_id.0,
                process_id.0,
                stack_addr
            );
        } else {
            return Err(KernelError::ProcessNotFound { pid: process_id.0 });
        }

        Ok(handle)
    }

    /// Looks up a thread handle by ID.
    pub fn get_thread(&self, thread_id: ThreadId) -> Option<Arc<ThreadHandle>> {
        self.threads.read().get(&thread_id).cloned()
    }

    /// Waits until `thread_id` terminates, returns its exit value, and
    /// removes the handle from the registry (a thread is joined at most once).
    ///
    /// The wait is a yield loop, not a true blocking sleep.
    ///
    /// # Errors
    /// `ThreadNotFound` for unknown IDs; `InvalidState` if detached.
    pub fn join_thread(&self, thread_id: ThreadId) -> Result<*mut u8, KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        if !handle.is_joinable() {
            return Err(KernelError::InvalidState {
                expected: "joinable",
                actual: "detached",
            });
        }

        loop {
            if handle.get_state() == ThreadState::Terminated {
                // `take()` so any racing joiner observes a null exit value.
                let exit_value = handle
                    .exit_value
                    .lock()
                    .take()
                    .unwrap_or(core::ptr::null_mut());

                // Joined threads are reaped from the registry here.
                self.threads.write().remove(&thread_id);

                return Ok(exit_value);
            }

            crate::sched::yield_cpu();
        }
    }

    /// Marks `thread_id` detached: it can no longer be joined and will not
    /// be reaped via `join_thread`.
    pub fn detach_thread(&self, thread_id: ThreadId) -> Result<(), KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        handle.detach();
        Ok(())
    }

    /// Requests cooperative cancellation of `thread_id`.
    ///
    /// Only sets the flag (and readies a `Blocked` thread so it can notice
    /// it); the target must poll `is_cancel_requested` itself.
    pub fn cancel_thread(&self, thread_id: ThreadId) -> Result<(), KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        handle.request_cancel();

        // Wake blocked threads so they can observe the cancel flag.
        if handle.get_state() == ThreadState::Blocked {
            handle.set_state(ThreadState::Ready);
        }

        crate::println!("[THREAD] Cancellation requested for thread {}", thread_id.0);

        Ok(())
    }

    /// Records `exit_value`, marks the thread `Terminated`, and runs its TLS
    /// destructors. Unknown IDs are silently ignored.
    pub fn exit_thread(&self, thread_id: ThreadId, exit_value: *mut u8) {
        if let Some(handle) = self.get_thread(thread_id) {
            *handle.exit_value.lock() = Some(exit_value);
            handle.set_state(ThreadState::Terminated);

            self.run_tls_destructors(thread_id);

            crate::println!("[THREAD] Thread {} exited", thread_id.0);
        }
    }

    /// Sets a thread's priority.
    ///
    /// NOTE(review): the new priority is currently only logged — neither the
    /// scheduler-level thread nor `handle.attributes` is updated, so
    /// `get_thread_priority` will not reflect this call. Confirm intent.
    pub fn set_thread_priority(
        &self,
        thread_id: ThreadId,
        _priority: ThreadPriority,
    ) -> Result<(), KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        if let Some(process) = get_process(handle.process_id) {
            if let Some(_thread) = process.get_thread(thread_id) {
                crate::println!(
                    "[THREAD] Updated thread {} priority (tracked in handle)",
                    thread_id.0
                );
            }
        }

        crate::println!(
            "[THREAD] Set thread {} priority to {:?}",
            thread_id.0,
            _priority
        );
        Ok(())
    }

    /// Reads a thread's priority from the scheduler-level thread.
    ///
    /// The numeric scale used here is inverted relative to the
    /// `ThreadPriority` discriminants: 0 maps to `RealTime`, 4+ to `Idle`.
    /// NOTE(review): confirm this matches `Thread::priority`'s convention.
    ///
    /// # Errors
    /// `ThreadNotFound` for unknown IDs; `NotFound` if the scheduler-level
    /// thread context is missing.
    pub fn get_thread_priority(&self, thread_id: ThreadId) -> Result<ThreadPriority, KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        if let Some(process) = get_process(handle.process_id) {
            if let Some(thread) = process.get_thread(thread_id) {
                let priority = match thread.priority {
                    0 => ThreadPriority::RealTime,
                    1 => ThreadPriority::High,
                    2 => ThreadPriority::Normal,
                    3 => ThreadPriority::Low,
                    _ => ThreadPriority::Idle,
                };
                return Ok(priority);
            }
        }

        Err(KernelError::NotFound {
            resource: "thread context",
            id: thread_id.0,
        })
    }

    /// Restricts `thread_id` to the CPUs set in `cpu_mask`.
    ///
    /// NOTE(review): `cpu_mask as usize` truncates on 32-bit targets —
    /// confirm the supported CPU count fits.
    pub fn set_cpu_affinity(&self, thread_id: ThreadId, cpu_mask: u64) -> Result<(), KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        if let Some(process) = get_process(handle.process_id) {
            if let Some(thread) = process.get_thread(thread_id) {
                thread.set_affinity(cpu_mask as usize);
                crate::println!(
                    "[THREAD] Set thread {} CPU affinity to 0x{:x}",
                    thread_id.0,
                    cpu_mask
                );
                return Ok(());
            }
        }

        Err(KernelError::NotFound {
            resource: "thread context",
            id: thread_id.0,
        })
    }

    /// Allocates a new TLS key, optionally registering a destructor that
    /// runs on thread exit (cf. `pthread_key_create`). Always succeeds.
    pub fn create_tls_key(&self, destructor: Option<fn(*mut u8)>) -> Result<TlsKey, KernelError> {
        let key = TlsKey(self.next_tls_key.fetch_add(1, Ordering::SeqCst) as u32);

        if let Some(dtor) = destructor {
            self.tls_destructors.write().insert(key, dtor);
        }

        Ok(key)
    }

    /// Deletes a TLS key: unregisters its destructor and clears the slot in
    /// every known thread. Destructors are NOT run for existing values
    /// (matching `pthread_key_delete` semantics).
    pub fn delete_tls_key(&self, key: TlsKey) -> Result<(), KernelError> {
        self.tls_destructors.write().remove(&key);

        let threads = self.threads.read();
        for handle in threads.values() {
            if let Some(process) = get_process(handle.process_id) {
                if let Some(thread) = process.get_thread(handle.id) {
                    #[cfg(feature = "alloc")]
                    {
                        thread.remove_tls_value(key.0 as u64);
                    }
                }
            }
        }

        Ok(())
    }

    /// Stores `value` in `thread_id`'s slot for `key`.
    ///
    /// NOTE(review): without the "alloc" feature this falls through to the
    /// `NotFound` error even when the thread exists, unlike `get_tls_value`
    /// which returns null in that configuration. Confirm intent.
    pub fn set_tls_value(
        &self,
        thread_id: ThreadId,
        key: TlsKey,
        value: *mut u8,
    ) -> Result<(), KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        if let Some(process) = get_process(handle.process_id) {
            if let Some(thread) = process.get_thread(thread_id) {
                #[cfg(feature = "alloc")]
                {
                    thread.set_tls_value(key.0 as u64, value as u64);
                    return Ok(());
                }
            }
        }

        Err(KernelError::NotFound {
            resource: "thread context",
            id: thread_id.0,
        })
    }

    /// Reads `thread_id`'s slot for `key`; unset slots read as null.
    pub fn get_tls_value(&self, thread_id: ThreadId, key: TlsKey) -> Result<*mut u8, KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        if let Some(process) = get_process(handle.process_id) {
            if let Some(thread) = process.get_thread(thread_id) {
                #[cfg(feature = "alloc")]
                {
                    let value = thread.get_tls_value(key.0 as u64).unwrap_or(0) as *mut u8;
                    return Ok(value);
                }

                #[cfg(not(feature = "alloc"))]
                return Ok(core::ptr::null_mut());
            }
        }

        Err(KernelError::NotFound {
            resource: "thread context",
            id: thread_id.0,
        })
    }

    /// ID of the currently running thread, or `None` when the scheduler
    /// reports 0 (its "no thread" sentinel).
    pub fn get_current_thread_id(&self) -> Option<ThreadId> {
        let tid = crate::sched::get_current_thread_id();
        if tid != 0 {
            Some(ThreadId(tid))
        } else {
            None
        }
    }

    /// Snapshot of all registered thread IDs.
    pub fn list_threads(&self) -> Vec<ThreadId> {
        self.threads.read().keys().copied().collect()
    }

    /// Snapshot of a thread's bookkeeping data.
    pub fn get_thread_stats(&self, thread_id: ThreadId) -> Result<ThreadStats, KernelError> {
        let handle = self
            .get_thread(thread_id)
            .ok_or(KernelError::ThreadNotFound { tid: thread_id.0 })?;

        Ok(ThreadStats {
            id: thread_id,
            process_id: handle.process_id,
            state: handle.get_state(),
            priority: handle.attributes.priority,
            cpu_time: handle.get_cpu_time(),
            context_switches: handle.get_context_switches(),
            stack_size: handle.attributes.stack_size,
            name: handle.attributes.name.clone(),
        })
    }

    /// Runs registered TLS destructors for `thread_id`'s non-zero values.
    ///
    /// Called from `exit_thread`. Each destructor runs at most once (POSIX
    /// re-loops until all values are null; this does a single pass and does
    /// not clear the values).
    /// NOTE(review): the `tls_destructors` read lock is held across the
    /// destructor callbacks — a destructor calling `create_tls_key` /
    /// `delete_tls_key` (write lock) would deadlock. Confirm destructors
    /// never re-enter the manager.
    fn run_tls_destructors(&self, thread_id: ThreadId) {
        let destructors = self.tls_destructors.read();

        if let Some(handle) = self.get_thread(thread_id) {
            if let Some(process) = get_process(handle.process_id) {
                if let Some(thread) = process.get_thread(thread_id) {
                    #[cfg(feature = "alloc")]
                    {
                        for (key, dtor) in destructors.iter() {
                            if let Some(value) = thread.get_tls_value(key.0 as u64) {
                                if value != 0 {
                                    dtor(value as *mut u8);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
621
/// Point-in-time snapshot of a thread's bookkeeping data, produced by
/// `ThreadManager::get_thread_stats`.
#[derive(Debug, Clone)]
pub struct ThreadStats {
    pub id: ThreadId,
    pub process_id: ProcessId,
    pub state: ThreadState,
    // Priority as recorded in the creation attributes (may not reflect
    // later set_thread_priority calls — see that method's note).
    pub priority: ThreadPriority,
    pub cpu_time: u64,
    pub context_switches: u64,
    pub stack_size: usize,
    pub name: String,
}
634
// Global singleton, populated exactly once by `init()`.
static THREAD_MANAGER: crate::sync::once_lock::OnceLock<ThreadManager> =
    crate::sync::once_lock::OnceLock::new();
638
639pub fn init() {
641 #[allow(unused_imports)]
642 use crate::println;
643
644 println!("[THREAD_API] Creating new ThreadManager...");
645 match THREAD_MANAGER.set(ThreadManager::new()) {
646 Ok(()) => println!("[THREAD_API] Thread management APIs initialized"),
647 Err(_) => println!("[THREAD_API] Already initialized, skipping..."),
648 }
649}
650
/// Non-panicking accessor: `None` if [`init`] has not run yet.
pub fn try_get_thread_manager() -> Option<&'static ThreadManager> {
    THREAD_MANAGER.get()
}
657
/// Global accessor for the thread manager.
///
/// # Panics
/// Panics if [`init`] has not been called first.
pub fn get_thread_manager() -> &'static ThreadManager {
    THREAD_MANAGER
        .get()
        .expect("Thread manager not initialized: init() was not called")
}
667
668pub fn create_thread(
672 entry_point: ThreadEntryPoint,
673 arg: *mut u8,
674 attributes: ThreadAttributes,
675 process_id: ProcessId,
676) -> crate::error::KernelResult<Arc<ThreadHandle>> {
677 let params = ThreadCreateParams {
678 entry_point,
679 arg,
680 attributes,
681 };
682
683 get_thread_manager().create_thread(params, process_id)
684}
685
/// Convenience wrapper over [`ThreadManager::join_thread`]; waits (yielding)
/// until the target terminates and returns its exit value.
pub fn join_thread(thread_id: ThreadId) -> crate::error::KernelResult<*mut u8> {
    get_thread_manager().join_thread(thread_id)
}
690
/// Terminates the calling thread, recording `exit_value` for any joiner.
/// Never returns.
pub fn exit_thread(exit_value: *mut u8) -> ! {
    let current_thread_id = ThreadId(crate::sched::get_current_thread_id());
    get_thread_manager().exit_thread(current_thread_id, exit_value);

    // Hand the CPU back so the scheduler can switch away from us.
    crate::sched::yield_cpu();

    // If control ever comes back (scheduler has not reaped us yet), spin
    // forever rather than returning into dead code.
    loop {
        core::hint::spin_loop();
    }
}
705
/// Voluntarily gives the CPU back to the scheduler.
pub fn yield_thread() {
    crate::sched::yield_cpu();
}
710
711pub fn sleep_ms(ms: u64) {
716 use crate::arch::timer::get_ticks;
717
718 let start_ticks = get_ticks();
720
721 const TICKS_PER_MS: u64 = 1;
725 let target_ticks = ms.saturating_mul(TICKS_PER_MS);
726 let end_ticks = start_ticks.saturating_add(target_ticks);
727
728 while get_ticks() < end_ticks {
730 yield_thread();
732
733 for _ in 0..100 {
735 core::hint::spin_loop();
736 }
737 }
738}