1use alloc::{
8 collections::BTreeMap,
9 string::{String, ToString},
10 vec::Vec,
11};
12
13use spin::RwLock;
14
15use super::{sync_receive, sync_send, EndpointId, IpcError, Message, SmallMessage};
16use crate::sync::once_lock::OnceLock;
17
/// Identifier of a method inside an RPC service (carried in data word 1).
pub type MethodId = u32;

/// Identifier a client assigns to each call so a reply can be matched to it.
pub type RequestId = u64;

/// Kind of RPC message; its discriminant is passed as the second argument of
/// `SmallMessage::new` when building requests, responses and error replies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RpcMessageType {
    /// Call sent from a client to a server.
    Request = 1,
    /// Successful reply carrying the handler's result.
    Response = 2,
    /// Failure reply carrying an error code.
    Error = 3,
}

32#[derive(Debug, Clone)]
34pub struct RpcError {
35 pub request_id: RequestId,
36 pub error_code: i32,
37 pub message: String,
38}
39
40impl From<IpcError> for RpcError {
41 fn from(err: IpcError) -> Self {
42 let error_code = match err {
43 IpcError::InvalidCapability => -1,
44 IpcError::PermissionDenied => -2,
45 IpcError::WouldBlock => -3,
46 IpcError::Timeout => -4,
47 IpcError::ProcessNotFound => -5,
48 IpcError::EndpointNotFound => -6,
49 IpcError::MessageTooLarge => -7,
50 IpcError::OutOfMemory => -8,
51 IpcError::RateLimitExceeded => -9,
52 IpcError::InvalidMessage => -10,
53 IpcError::ChannelFull => -11,
54 IpcError::ChannelEmpty => -12,
55 IpcError::EndpointBusy => -13,
56 IpcError::InvalidMemoryRegion => -14,
57 IpcError::ResourceBusy => -15,
58 IpcError::NotInitialized => -16,
59 };
60 RpcError {
61 request_id: 0,
62 error_code,
63 message: alloc::format!("{:?}", err),
64 }
65 }
66}
67
68pub trait RpcService: Send + Sync {
70 fn name(&self) -> &str;
72
73 fn handle_method(&self, method_id: MethodId, params: &[u8]) -> Result<Vec<u8>, RpcError>;
75
76 fn methods(&self) -> Vec<MethodId>;
78}
79
80pub struct RpcClient {
82 endpoint_id: EndpointId,
83 next_request_id: RwLock<u64>,
84}
85
86impl RpcClient {
87 pub fn new(endpoint_id: EndpointId) -> Self {
89 Self {
90 endpoint_id,
91 next_request_id: RwLock::new(1),
92 }
93 }
94
95 pub fn call(&self, method_id: MethodId, params: Vec<u8>) -> Result<Vec<u8>, RpcError> {
99 let request_id = {
100 let mut next = self.next_request_id.write();
101 let id = *next;
102 *next += 1;
103 id
104 };
105
106 if params.len() <= 24 {
109 let mut msg = SmallMessage::new(0, RpcMessageType::Request as u32);
110 msg.data[0] = request_id;
111 msg.data[1] = method_id as u64;
112 msg.data[2] = params.len() as u64;
113
114 for (i, chunk) in params.chunks(8).enumerate() {
116 if i + 3 < 4 {
117 let mut value = 0u64;
119 for (j, &byte) in chunk.iter().enumerate() {
120 value |= (byte as u64) << (j * 8);
121 }
122 msg.data[3] = value;
123 }
124 }
125
126 sync_send(Message::Small(msg), self.endpoint_id)?;
128
129 let response = sync_receive(self.endpoint_id)?;
131
132 match response {
133 Message::Small(resp_msg) => {
134 let resp_len = resp_msg.data[2] as usize;
136 let mut result = Vec::with_capacity(resp_len);
137
138 let value = resp_msg.data[3];
140 for i in 0..resp_len.min(8) {
141 result.push(((value >> (i * 8)) & 0xFF) as u8);
142 }
143
144 Ok(result)
145 }
146 Message::Large(_) => {
147 Err(RpcError {
149 request_id,
150 error_code: -100,
151 message: "Unexpected large response".to_string(),
152 })
153 }
154 }
155 } else {
156 Err(RpcError {
159 request_id,
160 error_code: -101,
161 message: "Large RPC calls not yet implemented".to_string(),
162 })
163 }
164 }
165}
166
167pub struct RpcServer {
169 endpoint_id: EndpointId,
170 services: RwLock<BTreeMap<String, alloc::boxed::Box<dyn RpcService>>>,
171 method_dispatch: RwLock<BTreeMap<MethodId, String>>,
173}
174
175impl RpcServer {
176 pub fn new(endpoint_id: EndpointId) -> Self {
178 Self {
179 endpoint_id,
180 services: RwLock::new(BTreeMap::new()),
181 method_dispatch: RwLock::new(BTreeMap::new()),
182 }
183 }
184
185 pub fn register_service(&self, service: alloc::boxed::Box<dyn RpcService>) {
187 let name = service.name().to_string();
188 let methods = service.methods();
190 let mut dispatch = self.method_dispatch.write();
191 for method_id in methods {
192 dispatch.insert(method_id, name.clone());
193 }
194 drop(dispatch);
195 self.services.write().insert(name, service);
196 }
197
198 pub fn process_requests(&self) -> Result<(), RpcError> {
203 let request = sync_receive(self.endpoint_id).map_err(RpcError::from)?;
205
206 match request {
207 Message::Small(msg) => {
208 let request_id = msg.data[0];
210 let method_id = msg.data[1] as u32;
211 let params_len = msg.data[2] as usize;
212
213 let mut params = Vec::with_capacity(params_len);
215 let value = msg.data[3];
216 for i in 0..params_len.min(8) {
217 params.push(((value >> (i * 8)) & 0xFF) as u8);
218 }
219
220 let dispatch = self.method_dispatch.read();
222 let service_name = dispatch.get(&method_id);
223
224 let mut result = Vec::new();
225 let mut found = false;
226
227 if let Some(name) = service_name {
228 let services = self.services.read();
229 if let Some(service) = services.get(name) {
230 match service.handle_method(method_id, ¶ms) {
231 Ok(response_data) => {
232 result = response_data;
233 found = true;
234 }
235 Err(err) => {
236 let error_msg = SmallMessage::new(0, RpcMessageType::Error as u32)
238 .with_data(0, request_id)
239 .with_data(1, err.error_code as u64);
240
241 sync_send(Message::Small(error_msg), self.endpoint_id)
242 .map_err(RpcError::from)?;
243 return Ok(());
244 }
245 }
246 }
247 }
248
249 if !found {
250 let error_msg = SmallMessage::new(0, RpcMessageType::Error as u32)
252 .with_data(0, request_id)
253 .with_data(1, -404i64 as u64);
254
255 sync_send(Message::Small(error_msg), self.endpoint_id)
256 .map_err(RpcError::from)?;
257 return Ok(());
258 }
259
260 let mut response_msg = SmallMessage::new(0, RpcMessageType::Response as u32);
262 response_msg.data[0] = request_id;
263 response_msg.data[1] = method_id as u64;
264 response_msg.data[2] = result.len() as u64;
265
266 if result.len() <= 8 {
268 let mut value = 0u64;
269 for (i, &byte) in result.iter().enumerate() {
270 value |= (byte as u64) << (i * 8);
271 }
272 response_msg.data[3] = value;
273 }
274
275 sync_send(Message::Small(response_msg), self.endpoint_id)
277 .map_err(RpcError::from)?;
278
279 Ok(())
280 }
281 Message::Large(_) => {
282 Err(RpcError {
284 request_id: 0,
285 error_code: -102,
286 message: "Large RPC requests not yet implemented".to_string(),
287 })
288 }
289 }
290 }
291}
292
293pub struct RpcRegistry {
295 services: RwLock<BTreeMap<String, EndpointId>>,
296}
297
298impl RpcRegistry {
299 pub fn new() -> Self {
301 Self {
302 services: RwLock::new(BTreeMap::new()),
303 }
304 }
305
306 pub fn register(&self, name: String, endpoint: EndpointId) {
308 self.services.write().insert(name, endpoint);
309 }
310
311 pub fn lookup(&self, name: &str) -> Option<EndpointId> {
313 self.services.read().get(name).copied()
314 }
315
316 pub fn list_services(&self) -> Vec<String> {
318 self.services.read().keys().cloned().collect()
319 }
320}
321
322impl Default for RpcRegistry {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328static GLOBAL_REGISTRY: RwLock<Option<RpcRegistry>> = RwLock::new(None);
330
331pub fn init() {
333 *GLOBAL_REGISTRY.write() = Some(RpcRegistry::new());
334 crate::println!("[RPC] RPC framework initialized (stub)");
335}
336
337static REGISTRY_STORAGE: OnceLock<RpcRegistry> = OnceLock::new();
339
340pub fn get_registry() -> &'static RpcRegistry {
342 REGISTRY_STORAGE.get_or_init(RpcRegistry::new)
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered name resolves to its endpoint; unknown names do not.
    #[test]
    fn test_rpc_registry() {
        let registry = RpcRegistry::new();
        registry.register(String::from("test_service"), 42u64);

        assert_eq!(registry.lookup("test_service"), Some(42u64));
        assert_eq!(registry.lookup("missing_service"), None);
    }

    /// Re-registering replaces the endpoint, and names are listed in order.
    #[test]
    fn test_rpc_registry_overwrite_and_list() {
        let registry = RpcRegistry::default();
        registry.register(String::from("b"), 2u64);
        registry.register(String::from("a"), 1u64);
        registry.register(String::from("b"), 3u64);

        assert_eq!(registry.lookup("b"), Some(3u64));
        assert_eq!(registry.list_services(), ["a", "b"]);
    }

    /// IPC errors convert to their fixed RPC error codes with no request id.
    #[test]
    fn test_ipc_error_conversion() {
        let err = RpcError::from(IpcError::Timeout);
        assert_eq!(err.error_code, -4);
        assert_eq!(err.request_id, 0);
    }
}