1#![allow(dead_code)]
7
8use alloc::{string::String, vec::Vec};
9
/// Type octet of an HTTP/2 frame header.
///
/// Only the frame types this module builds or inspects get a dedicated
/// variant; every other octet is carried verbatim in [`FrameType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
    /// `DATA`, octet `0x0`.
    Data,
    /// `HEADERS`, octet `0x1`.
    Headers,
    /// `SETTINGS`, octet `0x4`.
    Settings,
    /// `WINDOW_UPDATE`, octet `0x8`.
    WindowUpdate,
    /// `GOAWAY`, octet `0x7`.
    GoAway,
    /// Any octet without a dedicated variant, kept as-is.
    Unknown(u8),
}

impl FrameType {
    // Wire octets of the frame types that have their own variant.
    const DATA: u8 = 0x0;
    const HEADERS: u8 = 0x1;
    const SETTINGS: u8 = 0x4;
    const GOAWAY: u8 = 0x7;
    const WINDOW_UPDATE: u8 = 0x8;

    /// Decodes a frame-type octet.
    ///
    /// Octets without a dedicated variant become `Unknown(b)`, so
    /// `FrameType::from_byte(b).to_byte() == b` holds for every `b`.
    pub fn from_byte(b: u8) -> Self {
        match b {
            Self::DATA => Self::Data,
            Self::HEADERS => Self::Headers,
            Self::SETTINGS => Self::Settings,
            Self::GOAWAY => Self::GoAway,
            Self::WINDOW_UPDATE => Self::WindowUpdate,
            unrecognised => Self::Unknown(unrecognised),
        }
    }

    /// Encodes this frame type as its wire octet; `Unknown(b)` yields `b`.
    pub fn to_byte(self) -> u8 {
        match self {
            Self::Data => Self::DATA,
            Self::Headers => Self::HEADERS,
            Self::Settings => Self::SETTINGS,
            Self::GoAway => Self::GOAWAY,
            Self::WindowUpdate => Self::WINDOW_UPDATE,
            Self::Unknown(raw) => raw,
        }
    }
}
56
57#[derive(Debug, Clone)]
59pub struct Http2Frame {
60 pub length: u32,
62 pub frame_type: FrameType,
64 pub flags: u8,
66 pub stream_id: u32,
68 pub payload: Vec<u8>,
70}
71
72impl Http2Frame {
73 pub const HEADER_SIZE: usize = 9;
75
76 pub fn parse(data: &[u8]) -> Option<(Self, usize)> {
81 if data.len() < Self::HEADER_SIZE {
82 return None;
83 }
84
85 let length = ((data[0] as u32) << 16) | ((data[1] as u32) << 8) | (data[2] as u32);
86 let frame_type = FrameType::from_byte(data[3]);
87 let flags = data[4];
88 let stream_id = ((data[5] as u32 & 0x7F) << 24)
89 | ((data[6] as u32) << 16)
90 | ((data[7] as u32) << 8)
91 | (data[8] as u32);
92
93 let total_len = Self::HEADER_SIZE + length as usize;
94 if data.len() < total_len {
95 return None;
96 }
97
98 let payload = data[Self::HEADER_SIZE..total_len].to_vec();
99
100 Some((
101 Http2Frame {
102 length,
103 frame_type,
104 flags,
105 stream_id,
106 payload,
107 },
108 total_len,
109 ))
110 }
111
112 pub fn serialize(&self) -> Vec<u8> {
114 let mut buf = Vec::with_capacity(Self::HEADER_SIZE + self.payload.len());
115 buf.push(((self.length >> 16) & 0xFF) as u8);
117 buf.push(((self.length >> 8) & 0xFF) as u8);
118 buf.push((self.length & 0xFF) as u8);
119 buf.push(self.frame_type.to_byte());
121 buf.push(self.flags);
123 buf.push(((self.stream_id >> 24) & 0x7F) as u8);
125 buf.push(((self.stream_id >> 16) & 0xFF) as u8);
126 buf.push(((self.stream_id >> 8) & 0xFF) as u8);
127 buf.push((self.stream_id & 0xFF) as u8);
128 buf.extend_from_slice(&self.payload);
130 buf
131 }
132
133 pub fn data(stream_id: u32, payload: Vec<u8>, end_stream: bool) -> Self {
135 let flags = if end_stream { 0x01 } else { 0x00 };
136 Http2Frame {
137 length: payload.len() as u32,
138 frame_type: FrameType::Data,
139 flags,
140 stream_id,
141 payload,
142 }
143 }
144
145 pub fn settings(stream_id: u32, payload: Vec<u8>) -> Self {
147 Http2Frame {
148 length: payload.len() as u32,
149 frame_type: FrameType::Settings,
150 flags: 0,
151 stream_id,
152 payload,
153 }
154 }
155
156 pub fn settings_ack() -> Self {
158 Http2Frame {
159 length: 0,
160 frame_type: FrameType::Settings,
161 flags: 0x01,
162 stream_id: 0,
163 payload: Vec::new(),
164 }
165 }
166
167 pub fn window_update(stream_id: u32, increment: u32) -> Self {
169 let payload = alloc::vec![
170 ((increment >> 24) & 0x7F) as u8,
171 ((increment >> 16) & 0xFF) as u8,
172 ((increment >> 8) & 0xFF) as u8,
173 (increment & 0xFF) as u8,
174 ];
175 Http2Frame {
176 length: 4,
177 frame_type: FrameType::WindowUpdate,
178 flags: 0,
179 stream_id,
180 payload,
181 }
182 }
183
184 pub fn goaway(last_stream_id: u32, error_code: u32) -> Self {
186 let payload = alloc::vec![
187 ((last_stream_id >> 24) & 0x7F) as u8,
188 ((last_stream_id >> 16) & 0xFF) as u8,
189 ((last_stream_id >> 8) & 0xFF) as u8,
190 (last_stream_id & 0xFF) as u8,
191 ((error_code >> 24) & 0xFF) as u8,
192 ((error_code >> 16) & 0xFF) as u8,
193 ((error_code >> 8) & 0xFF) as u8,
194 (error_code & 0xFF) as u8,
195 ];
196 Http2Frame {
197 length: 8,
198 frame_type: FrameType::GoAway,
199 flags: 0,
200 stream_id: 0,
201 payload,
202 }
203 }
204
205 pub fn is_end_stream(&self) -> bool {
207 self.flags & 0x01 != 0
208 }
209
210 pub fn is_end_headers(&self) -> bool {
212 self.flags & 0x04 != 0
213 }
214}
215
/// The HPACK static table as `(name, value)` pairs.
///
/// HPACK indices are 1-based, so HPACK index `i` is stored at
/// `HPACK_STATIC_TABLE[i - 1]`. The trailing comment on each row gives its
/// HPACK index.
pub const HPACK_STATIC_TABLE: &[(&str, &str); 61] = &[
    (":authority", ""),                  // 1
    (":method", "GET"),                  // 2
    (":method", "POST"),                 // 3
    (":path", "/"),                      // 4
    (":path", "/index.html"),            // 5
    (":scheme", "http"),                 // 6
    (":scheme", "https"),                // 7
    (":status", "200"),                  // 8
    (":status", "204"),                  // 9
    (":status", "206"),                  // 10
    (":status", "304"),                  // 11
    (":status", "400"),                  // 12
    (":status", "404"),                  // 13
    (":status", "500"),                  // 14
    ("accept-charset", ""),              // 15
    ("accept-encoding", "gzip, deflate"), // 16
    ("accept-language", ""),             // 17
    ("accept-ranges", ""),               // 18
    ("accept", ""),                      // 19
    ("access-control-allow-origin", ""), // 20
    ("age", ""),                         // 21
    ("allow", ""),                       // 22
    ("authorization", ""),               // 23
    ("cache-control", ""),               // 24
    ("content-disposition", ""),         // 25
    ("content-encoding", ""),            // 26
    ("content-language", ""),            // 27
    ("content-length", ""),              // 28
    ("content-location", ""),            // 29
    ("content-range", ""),               // 30
    ("content-type", ""),                // 31
    ("cookie", ""),                      // 32
    ("date", ""),                        // 33
    ("etag", ""),                        // 34
    ("expect", ""),                      // 35
    ("expires", ""),                     // 36
    ("from", ""),                        // 37
    ("host", ""),                        // 38
    ("if-match", ""),                    // 39
    ("if-modified-since", ""),           // 40
    ("if-none-match", ""),               // 41
    ("if-range", ""),                    // 42
    ("if-unmodified-since", ""),         // 43
    ("last-modified", ""),               // 44
    ("link", ""),                        // 45
    ("location", ""),                    // 46
    ("max-forwards", ""),                // 47
    ("proxy-authenticate", ""),          // 48
    ("proxy-authorization", ""),         // 49
    ("range", ""),                       // 50
    ("referer", ""),                     // 51
    ("refresh", ""),                     // 52
    ("retry-after", ""),                 // 53
    ("server", ""),                      // 54
    ("set-cookie", ""),                  // 55
    ("strict-transport-security", ""),   // 56
    ("transfer-encoding", ""),           // 57
    ("user-agent", ""),                  // 58
    ("vary", ""),                        // 59
    ("via", ""),                         // 60
    ("www-authenticate", ""),            // 61
];
345
346pub fn hpack_static_lookup(index: usize) -> Option<(&'static str, &'static str)> {
348 if index == 0 || index > HPACK_STATIC_TABLE.len() {
349 return None;
350 }
351 Some(HPACK_STATIC_TABLE[index - 1])
352}
353
354pub fn hpack_static_find_name(name: &str) -> Option<usize> {
356 for (i, (n, _)) in HPACK_STATIC_TABLE.iter().enumerate() {
357 if *n == name {
358 return Some(i + 1);
359 }
360 }
361 None
362}
363
/// A gRPC request addressed to `/{service}/{method}` with an already
/// serialized message body.
#[derive(Debug, Clone)]
pub struct GrpcMessage {
    /// Fully-qualified service name, e.g. `runtime.v1.RuntimeService`.
    pub service: String,
    /// Method name within the service, e.g. `RunPodSandbox`.
    pub method: String,
    /// Serialized message bytes (the serialization format is not visible here).
    pub payload: Vec<u8>,
}

impl GrpcMessage {
    /// Length of the gRPC message prefix: one compressed-flag octet followed
    /// by a 4-octet big-endian message length.
    const PREFIX_LEN: usize = 5;

    /// Creates a message for `service`/`method` carrying `payload`.
    pub fn new(service: String, method: String, payload: Vec<u8>) -> Self {
        GrpcMessage {
            service,
            method,
            payload,
        }
    }

    /// Frames `payload` as a gRPC length-prefixed message.
    ///
    /// The output is the compressed flag `0`, then the payload length as a
    /// big-endian `u32`, then the payload bytes.
    ///
    /// NOTE: the length is cast to `u32`, so a payload longer than `u32::MAX`
    /// bytes would get a truncated prefix.
    pub fn encode_payload(&self) -> Vec<u8> {
        let len = self.payload.len() as u32;
        let mut buf = Vec::with_capacity(Self::PREFIX_LEN + self.payload.len());
        buf.push(0); // not compressed
        buf.extend_from_slice(&len.to_be_bytes());
        buf.extend_from_slice(&self.payload);
        buf
    }

    /// Decodes one length-prefixed gRPC message from the front of `data`.
    ///
    /// Returns the message bytes and the number of bytes consumed (prefix plus
    /// message). Returns `None` if `data` is shorter than the 5-byte prefix or
    /// shorter than the length the prefix declares. Bytes after the message are
    /// left for the caller.
    ///
    /// The compressed-flag octet is not inspected, so a message sent with the
    /// flag set is returned without decompression.
    pub fn decode_payload(data: &[u8]) -> Option<(Vec<u8>, usize)> {
        let prefix = data.get(..Self::PREFIX_LEN)?;
        let declared = u32::from_be_bytes([prefix[1], prefix[2], prefix[3], prefix[4]]);

        // `declared` comes straight off the wire. On 32-bit targets a value near
        // `u32::MAX` would wrap `usize` when the prefix length is added, and the
        // slice below would then panic. Checked addition turns that into `None`.
        let total = (declared as usize).checked_add(Self::PREFIX_LEN)?;
        let message = data.get(Self::PREFIX_LEN..total)?;
        Some((message.to_vec(), total))
    }

    /// Returns the request path `/{service}/{method}`.
    pub fn path(&self) -> String {
        let mut p = String::with_capacity(2 + self.service.len() + self.method.len());
        p.push('/');
        p.push_str(&self.service);
        p.push('/');
        p.push_str(&self.method);
        p
    }
}
443
/// gRPC status codes understood by this module.
///
/// [`GrpcStatus::from_code`] maps any numeric code without a dedicated
/// variant (for example 9, 10, 11, 15 or 16) to [`GrpcStatus::Unknown`], so
/// such codes do not survive a `from_code` / `code` round trip.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrpcStatus {
    /// Code 0.
    Ok,
    /// Code 1.
    Cancelled,
    /// Code 2; also the fallback for unmodelled codes.
    Unknown,
    /// Code 3.
    InvalidArgument,
    /// Code 4.
    DeadlineExceeded,
    /// Code 5.
    NotFound,
    /// Code 6.
    AlreadyExists,
    /// Code 7.
    PermissionDenied,
    /// Code 8.
    ResourceExhausted,
    /// Code 12.
    Unimplemented,
    /// Code 13.
    Internal,
    /// Code 14.
    Unavailable,
}

impl GrpcStatus {
    /// Every modelled status; searched by `from_code` so that it is always the
    /// inverse of `code`.
    const ALL: [GrpcStatus; 12] = [
        GrpcStatus::Ok,
        GrpcStatus::Cancelled,
        GrpcStatus::Unknown,
        GrpcStatus::InvalidArgument,
        GrpcStatus::DeadlineExceeded,
        GrpcStatus::NotFound,
        GrpcStatus::AlreadyExists,
        GrpcStatus::PermissionDenied,
        GrpcStatus::ResourceExhausted,
        GrpcStatus::Unimplemented,
        GrpcStatus::Internal,
        GrpcStatus::Unavailable,
    ];

    /// Returns the numeric gRPC code for this status.
    pub fn code(self) -> u32 {
        match self {
            Self::Ok => 0,
            Self::Cancelled => 1,
            Self::Unknown => 2,
            Self::InvalidArgument => 3,
            Self::DeadlineExceeded => 4,
            Self::NotFound => 5,
            Self::AlreadyExists => 6,
            Self::PermissionDenied => 7,
            Self::ResourceExhausted => 8,
            Self::Unimplemented => 12,
            Self::Internal => 13,
            Self::Unavailable => 14,
        }
    }

    /// Maps a numeric gRPC code to a status; codes without a variant become
    /// [`GrpcStatus::Unknown`].
    pub fn from_code(code: u32) -> Self {
        Self::ALL
            .iter()
            .copied()
            .find(|status| status.code() == code)
            .unwrap_or(GrpcStatus::Unknown)
    }
}
514
515#[derive(Debug, Clone)]
517pub struct GrpcResponse {
518 pub status: GrpcStatus,
520 pub payload: Vec<u8>,
522 pub message: String,
524}
525
526impl GrpcResponse {
527 pub fn ok(payload: Vec<u8>) -> Self {
529 GrpcResponse {
530 status: GrpcStatus::Ok,
531 payload,
532 message: String::new(),
533 }
534 }
535
536 pub fn error(status: GrpcStatus, message: String) -> Self {
538 GrpcResponse {
539 status,
540 payload: Vec::new(),
541 message,
542 }
543 }
544}
545
546#[derive(Debug)]
548pub struct GrpcTransport {
549 socket_path: String,
551 next_stream_id: u32,
553 max_frame_size: u32,
555 window_size: u32,
557}
558
559impl GrpcTransport {
560 pub const DEFAULT_MAX_FRAME_SIZE: u32 = 16384;
562 pub const DEFAULT_WINDOW_SIZE: u32 = 65535;
564
565 pub fn new(socket_path: String) -> Self {
567 GrpcTransport {
568 socket_path,
569 next_stream_id: 1,
570 max_frame_size: Self::DEFAULT_MAX_FRAME_SIZE,
571 window_size: Self::DEFAULT_WINDOW_SIZE,
572 }
573 }
574
575 pub fn socket_path(&self) -> &str {
577 &self.socket_path
578 }
579
580 pub fn next_stream(&mut self) -> u32 {
582 let id = self.next_stream_id;
583 self.next_stream_id += 2;
584 id
585 }
586
587 pub fn build_request_frame(&mut self, msg: &GrpcMessage) -> Http2Frame {
589 let stream_id = self.next_stream();
590 let encoded = msg.encode_payload();
591 Http2Frame::data(stream_id, encoded, true)
592 }
593
594 pub fn parse_response_frame(frame: &Http2Frame) -> Option<GrpcResponse> {
596 if frame.frame_type != FrameType::Data {
597 return None;
598 }
599
600 match GrpcMessage::decode_payload(&frame.payload) {
601 Some((payload, _)) => Some(GrpcResponse::ok(payload)),
602 None => Some(GrpcResponse::error(
603 GrpcStatus::Internal,
604 String::from("failed to decode gRPC payload"),
605 )),
606 }
607 }
608
609 pub fn window_size(&self) -> u32 {
611 self.window_size
612 }
613
614 pub fn update_window(&mut self, delta: u32) {
616 self.window_size = self.window_size.saturating_add(delta);
617 }
618
619 pub fn consume_window(&mut self, amount: u32) -> bool {
621 if self.window_size >= amount {
622 self.window_size -= amount;
623 true
624 } else {
625 false
626 }
627 }
628}
629
#[cfg(test)]
mod tests {
    // `vec!` may also be reachable through the std prelude in test builds,
    // hence the allow.
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    // ---- FrameType -------------------------------------------------------

    #[test]
    fn test_frame_type_roundtrip() {
        // Every octet, known or not, must survive decode + encode.
        for byte in 0..=u8::MAX {
            let ft = FrameType::from_byte(byte);
            assert_eq!(ft.to_byte(), byte);
        }
    }

    // ---- Http2Frame ------------------------------------------------------

    #[test]
    fn test_http2_frame_parse_too_small() {
        let data = [0u8; 5];
        assert!(Http2Frame::parse(&data).is_none());
    }

    #[test]
    fn test_http2_frame_parse_incomplete_payload() {
        // Header declares 4 payload bytes but only 2 follow.
        let data = [0u8, 0, 4, 0x0, 0x0, 0, 0, 0, 1, 0xAA, 0xBB];
        assert!(Http2Frame::parse(&data).is_none());
    }

    #[test]
    fn test_http2_frame_parse_masks_reserved_bit() {
        let data = [0u8, 0, 0, 0x4, 0x1, 0x80, 0, 0, 5];
        let (frame, consumed) = Http2Frame::parse(&data).unwrap();
        assert_eq!(consumed, Http2Frame::HEADER_SIZE);
        assert_eq!(frame.stream_id, 5);
        assert_eq!(frame.frame_type, FrameType::Settings);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_http2_frame_parse_stops_at_frame_boundary() {
        let mut buf = Http2Frame::data(1, vec![1, 2, 3], false).serialize();
        let second = Http2Frame::settings_ack().serialize();
        buf.extend_from_slice(&second);

        let (first, consumed) = Http2Frame::parse(&buf).unwrap();
        assert_eq!(consumed, Http2Frame::HEADER_SIZE + 3);
        assert_eq!(first.payload, vec![1, 2, 3]);

        let (next, used) = Http2Frame::parse(&buf[consumed..]).unwrap();
        assert_eq!(next.frame_type, FrameType::Settings);
        assert_eq!(used, second.len());
    }

    #[test]
    fn test_http2_frame_parse_roundtrip() {
        let original = Http2Frame::data(3, vec![0xDE, 0xAD, 0xBE, 0xEF], true);
        let serialized = original.serialize();
        let (parsed, consumed) = Http2Frame::parse(&serialized).unwrap();
        assert_eq!(consumed, serialized.len());
        assert_eq!(parsed.stream_id, 3);
        assert_eq!(parsed.frame_type, FrameType::Data);
        assert!(parsed.is_end_stream());
        assert_eq!(parsed.payload, vec![0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn test_http2_settings_ack() {
        let frame = Http2Frame::settings_ack();
        assert_eq!(frame.frame_type, FrameType::Settings);
        assert_eq!(frame.flags, 0x01);
        assert_eq!(frame.stream_id, 0);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_http2_window_update() {
        let frame = Http2Frame::window_update(1, 32768);
        assert_eq!(frame.frame_type, FrameType::WindowUpdate);
        assert_eq!(frame.length, 4);
        let val = ((frame.payload[0] as u32 & 0x7F) << 24)
            | ((frame.payload[1] as u32) << 16)
            | ((frame.payload[2] as u32) << 8)
            | (frame.payload[3] as u32);
        assert_eq!(val, 32768);
    }

    #[test]
    fn test_http2_goaway() {
        let frame = Http2Frame::goaway(5, 0);
        assert_eq!(frame.frame_type, FrameType::GoAway);
        assert_eq!(frame.length, 8);
        assert_eq!(frame.stream_id, 0);
        assert_eq!(frame.payload, vec![0, 0, 0, 5, 0, 0, 0, 0]);

        // The reserved bit of the last-stream-id is cleared.
        let masked = Http2Frame::goaway(0xFFFF_FFFF, 2);
        assert_eq!(masked.payload, vec![0x7F, 0xFF, 0xFF, 0xFF, 0, 0, 0, 2]);
    }

    // ---- HPACK static table ---------------------------------------------

    #[test]
    fn test_hpack_static_lookup() {
        let (name, val) = hpack_static_lookup(2).unwrap();
        assert_eq!(name, ":method");
        assert_eq!(val, "GET");
        assert_eq!(hpack_static_lookup(1), Some((":authority", "")));
        assert_eq!(hpack_static_lookup(61), Some(("www-authenticate", "")));
        assert!(hpack_static_lookup(0).is_none());
        assert!(hpack_static_lookup(62).is_none());
        assert!(hpack_static_lookup(usize::MAX).is_none());
    }

    #[test]
    fn test_hpack_static_find_name() {
        assert_eq!(hpack_static_find_name(":authority"), Some(1));
        assert_eq!(hpack_static_find_name(":method"), Some(2));
        assert_eq!(hpack_static_find_name("content-type"), Some(31));
        assert_eq!(hpack_static_find_name("www-authenticate"), Some(61));
        assert!(hpack_static_find_name("x-custom").is_none());
    }

    // ---- gRPC messages ---------------------------------------------------

    #[test]
    fn test_grpc_message_encode_decode() {
        let msg = GrpcMessage::new(
            String::from("runtime.v1.RuntimeService"),
            String::from("RunPodSandbox"),
            vec![1, 2, 3, 4],
        );
        let encoded = msg.encode_payload();
        assert_eq!(encoded[0], 0);
        let decoded_len = ((encoded[1] as u32) << 24)
            | ((encoded[2] as u32) << 16)
            | ((encoded[3] as u32) << 8)
            | (encoded[4] as u32);
        assert_eq!(decoded_len, 4);

        let (payload, consumed) = GrpcMessage::decode_payload(&encoded).unwrap();
        assert_eq!(payload, vec![1, 2, 3, 4]);
        assert_eq!(consumed, 9);
    }

    #[test]
    fn test_grpc_message_decode_truncated() {
        assert!(GrpcMessage::decode_payload(&[0, 0, 0, 0]).is_none());
        assert!(GrpcMessage::decode_payload(&[0, 0, 0, 0, 3, 1]).is_none());
        assert!(GrpcMessage::decode_payload(&[0, 0xFF, 0xFF, 0xFF, 0xFF, 1]).is_none());
    }

    #[test]
    fn test_grpc_message_decode_leaves_trailing_bytes() {
        let data = [0u8, 0, 0, 0, 1, 0xAB, 0xCD];
        let (payload, consumed) = GrpcMessage::decode_payload(&data).unwrap();
        assert_eq!(payload, vec![0xAB]);
        assert_eq!(consumed, 6);
    }

    #[test]
    fn test_grpc_message_path() {
        let msg = GrpcMessage::new(
            String::from("runtime.v1.RuntimeService"),
            String::from("RunPodSandbox"),
            Vec::new(),
        );
        assert_eq!(msg.path(), "/runtime.v1.RuntimeService/RunPodSandbox");
    }

    // ---- gRPC status -----------------------------------------------------

    #[test]
    fn test_grpc_status_roundtrip() {
        for code in [0u32, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14] {
            let status = GrpcStatus::from_code(code);
            assert_eq!(status.code(), code);
        }
    }

    #[test]
    fn test_grpc_status_unmodelled_codes_are_unknown() {
        for code in [9u32, 10, 11, 15, 16, u32::MAX] {
            assert_eq!(GrpcStatus::from_code(code), GrpcStatus::Unknown);
        }
    }

    // ---- GrpcTransport ---------------------------------------------------

    #[test]
    fn test_grpc_transport_stream_ids() {
        let mut transport = GrpcTransport::new(String::from("/run/cri.sock"));
        assert_eq!(transport.next_stream(), 1);
        assert_eq!(transport.next_stream(), 3);
        assert_eq!(transport.next_stream(), 5);
    }

    #[test]
    fn test_grpc_transport_window() {
        let mut transport = GrpcTransport::new(String::from("/run/cri.sock"));
        assert_eq!(transport.window_size(), GrpcTransport::DEFAULT_WINDOW_SIZE);
        assert!(transport.consume_window(65_535));
        assert_eq!(transport.window_size(), 0);
        assert!(!transport.consume_window(1));
        assert_eq!(transport.window_size(), 0);
        transport.update_window(10);
        assert_eq!(transport.window_size(), 10);
        transport.update_window(u32::MAX);
        assert_eq!(transport.window_size(), u32::MAX);
    }

    #[test]
    fn test_grpc_transport_build_request_frame() {
        let mut transport = GrpcTransport::new(String::from("/run/cri.sock"));
        let msg = GrpcMessage::new(String::from("svc"), String::from("M"), vec![1, 2]);
        let frame = transport.build_request_frame(&msg);
        assert_eq!(frame.stream_id, 1);
        assert_eq!(frame.frame_type, FrameType::Data);
        assert!(frame.is_end_stream());
        assert_eq!(frame.payload, msg.encode_payload());
        assert_eq!(transport.next_stream(), 3);
        assert_eq!(transport.socket_path(), "/run/cri.sock");
    }

    #[test]
    fn test_grpc_transport_parse_response_frame() {
        let not_data = Http2Frame::settings_ack();
        assert!(GrpcTransport::parse_response_frame(&not_data).is_none());

        let msg = GrpcMessage::new(String::from("svc"), String::from("M"), vec![9, 8, 7]);
        let frame = Http2Frame::data(1, msg.encode_payload(), true);
        let resp = GrpcTransport::parse_response_frame(&frame).unwrap();
        assert_eq!(resp.status, GrpcStatus::Ok);
        assert_eq!(resp.payload, vec![9, 8, 7]);
        assert!(resp.message.is_empty());

        let garbage = Http2Frame::data(1, vec![0, 0, 0], true);
        let resp = GrpcTransport::parse_response_frame(&garbage).unwrap();
        assert_eq!(resp.status, GrpcStatus::Internal);
        assert!(resp.payload.is_empty());
    }
}