1use alloc::vec::Vec;
7
/// Errors reported by the zlib / DEFLATE decoding routines in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum InflateError {
    /// The two-byte zlib header failed validation (compression method is not
    /// 8, or the CMF/FLG pair is not a multiple of 31).
    InvalidHeader,
    /// A block used the reserved block type, or a Huffman code / code-length
    /// sequence could not be decoded.
    InvalidBlock,
    /// Decoding would produce more than the caller-supplied `max_output` bytes.
    BufferOverflow,
    /// The input ended before the requested bits or bytes could be read.
    IncompleteInput,
    /// A distance symbol has no table entry, or a back-reference points
    /// before the start of the output produced so far.
    InvalidDistance,
    /// A literal/length symbol has no entry in the length table.
    InvalidLitLen,
}
18
19pub(crate) fn check_zlib_header(data: &[u8]) -> Result<usize, InflateError> {
21 if data.len() < 2 {
22 return Err(InflateError::IncompleteInput);
23 }
24
25 let cmf = data[0];
26 let flg = data[1];
27
28 if (cmf & 0x0F) != 8 {
30 return Err(InflateError::InvalidHeader);
31 }
32
33 let check = (cmf as u16) * 256 + (flg as u16);
35 if !check.is_multiple_of(31) {
36 return Err(InflateError::InvalidHeader);
37 }
38
39 let dict_present = (flg & 0x20) != 0;
41 let header_size = if dict_present { 6 } else { 2 };
42
43 Ok(header_size)
44}
45
46struct BitReader<'a> {
48 data: &'a [u8],
49 pos: usize,
50 bit_pos: u8,
51}
52
53impl<'a> BitReader<'a> {
54 fn new(data: &'a [u8]) -> Self {
55 Self {
56 data,
57 pos: 0,
58 bit_pos: 0,
59 }
60 }
61
62 fn read_bits(&mut self, count: u8) -> Result<u32, InflateError> {
63 let mut val: u32 = 0;
64 for i in 0..count {
65 if self.pos >= self.data.len() {
66 return Err(InflateError::IncompleteInput);
67 }
68 let bit = (self.data[self.pos] >> self.bit_pos) & 1;
69 val |= (bit as u32) << i;
70 self.bit_pos += 1;
71 if self.bit_pos >= 8 {
72 self.bit_pos = 0;
73 self.pos += 1;
74 }
75 }
76 Ok(val)
77 }
78
79 fn align_byte(&mut self) {
80 if self.bit_pos > 0 {
81 self.bit_pos = 0;
82 self.pos += 1;
83 }
84 }
85
86 fn read_u16_le(&mut self) -> Result<u16, InflateError> {
87 self.align_byte();
88 if self.pos + 2 > self.data.len() {
89 return Err(InflateError::IncompleteInput);
90 }
91 let val = u16::from_le_bytes([self.data[self.pos], self.data[self.pos + 1]]);
92 self.pos += 2;
93 Ok(val)
94 }
95
96 fn bytes_consumed(&self) -> usize {
97 if self.bit_pos > 0 {
98 self.pos + 1
99 } else {
100 self.pos
101 }
102 }
103}
104
/// Base match length for each length symbol, indexed by `symbol - 257`.
/// The decoders add the value of the corresponding `LENGTH_EXTRA` bits to it.
const LENGTH_BASE: [u16; 29] = [
    3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131,
    163, 195, 227, 258,
];
110
/// Number of extra bits to read after each length symbol, indexed like
/// `LENGTH_BASE` (by `symbol - 257`).
const LENGTH_EXTRA: [u8; 29] = [
    0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0,
];
115
/// Base back-reference distance for each distance symbol (0..=29).
/// The decoders add the value of the corresponding `DIST_EXTRA` bits to it.
const DIST_BASE: [u16; 30] = [
    1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537,
    2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577,
];
121
/// Number of extra bits to read after each distance symbol, indexed like
/// `DIST_BASE`.
const DIST_EXTRA: [u8; 30] = [
    0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13,
    13,
];
127
128pub(crate) fn inflate_raw(data: &[u8], max_output: usize) -> Result<Vec<u8>, InflateError> {
130 let mut reader = BitReader::new(data);
131 let mut output = Vec::with_capacity(core::cmp::min(max_output, 65536));
132
133 loop {
134 let bfinal = reader.read_bits(1)?;
135 let btype = reader.read_bits(2)?;
136
137 match btype {
138 0 => {
140 let len = reader.read_u16_le()?;
141 let _nlen = reader.read_u16_le()?;
142
143 for _ in 0..len {
144 if output.len() >= max_output {
145 return Err(InflateError::BufferOverflow);
146 }
147 if reader.pos >= reader.data.len() {
148 return Err(InflateError::IncompleteInput);
149 }
150 output.push(reader.data[reader.pos]);
151 reader.pos += 1;
152 }
153 }
154 1 => {
156 inflate_fixed_huffman(&mut reader, &mut output, max_output)?;
157 }
158 2 => {
160 inflate_dynamic_huffman(&mut reader, &mut output, max_output)?;
161 }
162 _ => return Err(InflateError::InvalidBlock),
163 }
164
165 if bfinal != 0 {
166 break;
167 }
168 }
169
170 Ok(output)
171}
172
173pub(crate) fn inflate_zlib(data: &[u8], max_output: usize) -> Result<Vec<u8>, InflateError> {
175 let header_size = check_zlib_header(data)?;
176 inflate_raw(&data[header_size..], max_output)
177}
178
179fn decode_fixed_litlen(reader: &mut BitReader) -> Result<u16, InflateError> {
180 let mut code: u32 = 0;
187 for bits in 0..9u8 {
188 let bit = reader.read_bits(1)?;
189 code = (code << 1) | bit;
191
192 match bits + 1 {
193 7 => {
194 if code <= 0b0010111 {
195 return Ok((code + 256) as u16);
196 }
197 }
198 8 => {
199 if (0b00110000..=0b10111111).contains(&code) {
200 return Ok((code - 0b00110000) as u16);
201 }
202 if (0b11000000..=0b11000111).contains(&code) {
203 return Ok((code - 0b11000000 + 280) as u16);
204 }
205 }
206 9 => {
207 if (0b110010000..=0b111111111).contains(&code) {
208 return Ok((code - 0b110010000 + 144) as u16);
209 }
210 }
211 _ => {}
212 }
213 }
214
215 Err(InflateError::InvalidLitLen)
216}
217
218fn inflate_fixed_huffman(
219 reader: &mut BitReader,
220 output: &mut Vec<u8>,
221 max_output: usize,
222) -> Result<(), InflateError> {
223 loop {
224 let lit = decode_fixed_litlen(reader)?;
225
226 if lit < 256 {
227 if output.len() >= max_output {
228 return Err(InflateError::BufferOverflow);
229 }
230 output.push(lit as u8);
231 } else if lit == 256 {
232 return Ok(());
233 } else {
234 let len_idx = (lit - 257) as usize;
235 if len_idx >= LENGTH_BASE.len() {
236 return Err(InflateError::InvalidLitLen);
237 }
238 let length =
239 LENGTH_BASE[len_idx] as usize + reader.read_bits(LENGTH_EXTRA[len_idx])? as usize;
240
241 let dist_code = reader.read_bits(5)? as usize;
243 let dist_code = reverse_bits(dist_code as u32, 5) as usize;
245 if dist_code >= DIST_BASE.len() {
246 return Err(InflateError::InvalidDistance);
247 }
248 let distance =
249 DIST_BASE[dist_code] as usize + reader.read_bits(DIST_EXTRA[dist_code])? as usize;
250
251 if distance > output.len() {
252 return Err(InflateError::InvalidDistance);
253 }
254
255 for _ in 0..length {
256 if output.len() >= max_output {
257 return Err(InflateError::BufferOverflow);
258 }
259 let idx = output.len() - distance;
260 output.push(output[idx]);
261 }
262 }
263 }
264}
265
266fn inflate_dynamic_huffman(
267 reader: &mut BitReader,
268 output: &mut Vec<u8>,
269 max_output: usize,
270) -> Result<(), InflateError> {
271 let hlit = reader.read_bits(5)? as usize + 257;
272 let hdist = reader.read_bits(5)? as usize + 1;
273 let hclen = reader.read_bits(4)? as usize + 4;
274
275 const CL_ORDER: [usize; 19] = [
277 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
278 ];
279
280 let mut cl_lengths = [0u8; 19];
281 for i in 0..hclen {
282 cl_lengths[CL_ORDER[i]] = reader.read_bits(3)? as u8;
283 }
284
285 let cl_table = build_huffman_table(&cl_lengths)?;
287
288 let total = hlit + hdist;
290 let mut code_lengths = Vec::with_capacity(total);
291
292 while code_lengths.len() < total {
293 let sym = decode_huffman(reader, &cl_table)?;
294 match sym {
295 0..=15 => code_lengths.push(sym as u8),
296 16 => {
297 let repeat = reader.read_bits(2)? as usize + 3;
298 let prev = *code_lengths.last().ok_or(InflateError::InvalidBlock)?;
299 code_lengths.resize(code_lengths.len() + repeat, prev);
300 }
301 17 => {
302 let repeat = reader.read_bits(3)? as usize + 3;
303 code_lengths.resize(code_lengths.len() + repeat, 0);
304 }
305 18 => {
306 let repeat = reader.read_bits(7)? as usize + 11;
307 code_lengths.resize(code_lengths.len() + repeat, 0);
308 }
309 _ => return Err(InflateError::InvalidBlock),
310 }
311 }
312
313 let litlen_lengths = &code_lengths[..hlit];
314 let dist_lengths = &code_lengths[hlit..];
315
316 let litlen_table = build_huffman_table(litlen_lengths)?;
317 let dist_table = build_huffman_table(dist_lengths)?;
318
319 loop {
321 let sym = decode_huffman(reader, &litlen_table)?;
322
323 if sym < 256 {
324 if output.len() >= max_output {
325 return Err(InflateError::BufferOverflow);
326 }
327 output.push(sym as u8);
328 } else if sym == 256 {
329 return Ok(());
330 } else {
331 let len_idx = (sym - 257) as usize;
332 if len_idx >= LENGTH_BASE.len() {
333 return Err(InflateError::InvalidLitLen);
334 }
335 let length =
336 LENGTH_BASE[len_idx] as usize + reader.read_bits(LENGTH_EXTRA[len_idx])? as usize;
337
338 let dist_sym = decode_huffman(reader, &dist_table)? as usize;
339 if dist_sym >= DIST_BASE.len() {
340 return Err(InflateError::InvalidDistance);
341 }
342 let distance =
343 DIST_BASE[dist_sym] as usize + reader.read_bits(DIST_EXTRA[dist_sym])? as usize;
344
345 if distance > output.len() {
346 return Err(InflateError::InvalidDistance);
347 }
348
349 for _ in 0..length {
350 if output.len() >= max_output {
351 return Err(InflateError::BufferOverflow);
352 }
353 let idx = output.len() - distance;
354 output.push(output[idx]);
355 }
356 }
357 }
358}
359
/// One slot of a Huffman lookup table.
#[derive(Debug, Clone, Copy, Default)]
struct HuffEntry {
    /// Symbol decoded when this slot is hit.
    symbol: u16,
    /// Length in bits of the symbol's code; 0 (the default) marks a slot
    /// that no code maps to.
    length: u8,
}
366
/// Lookup table for one Huffman code, as built by `build_huffman_table`.
struct HuffTable {
    /// `1 << max_bits` slots, indexed by the next `max_bits` input bits as
    /// returned by `BitReader::read_bits`; empty when `max_bits` is 0.
    entries: Vec<HuffEntry>,
    /// Length in bits of the longest code; 0 means the table has no codes.
    max_bits: u8,
}
372
373fn build_huffman_table(lengths: &[u8]) -> Result<HuffTable, InflateError> {
374 let max_bits = *lengths.iter().max().unwrap_or(&0);
375 if max_bits == 0 {
376 return Ok(HuffTable {
377 entries: Vec::new(),
378 max_bits: 0,
379 });
380 }
381
382 let table_size = 1usize << max_bits;
383 let mut entries = alloc::vec![HuffEntry::default(); table_size];
384
385 let mut bl_count = [0u16; 16];
387 for &len in lengths {
388 bl_count[len as usize] += 1;
389 }
390 bl_count[0] = 0;
391
392 let mut next_code = [0u16; 16];
394 let mut code: u16 = 0;
395 for bits in 1..=max_bits {
396 code = (code + bl_count[bits as usize - 1]) << 1;
397 next_code[bits as usize] = code;
398 }
399
400 for (sym, &len) in lengths.iter().enumerate() {
402 if len == 0 {
403 continue;
404 }
405 let code = next_code[len as usize];
406 next_code[len as usize] += 1;
407
408 let reversed = reverse_bits(code as u32, len) as usize;
410 let fill_count = 1usize << (max_bits - len);
411 for i in 0..fill_count {
412 let idx = reversed | (i << len);
413 if idx < table_size {
414 entries[idx] = HuffEntry {
415 symbol: sym as u16,
416 length: len,
417 };
418 }
419 }
420 }
421
422 Ok(HuffTable { entries, max_bits })
423}
424
425fn decode_huffman(reader: &mut BitReader, table: &HuffTable) -> Result<u16, InflateError> {
426 if table.max_bits == 0 {
427 return Err(InflateError::InvalidBlock);
428 }
429
430 let bits = reader.read_bits(table.max_bits)? as usize;
431 let entry = &table.entries[bits];
432 if entry.length == 0 {
433 return Err(InflateError::InvalidBlock);
434 }
435
436 let unused = table.max_bits - entry.length;
438 if unused > 0 {
439 let total_bits = reader.pos * 8 + reader.bit_pos as usize;
441 let new_total = total_bits - unused as usize;
442 reader.pos = new_total / 8;
443 reader.bit_pos = (new_total % 8) as u8;
444 }
445
446 Ok(entry.symbol)
447}
448
/// Reverses the order of the low `bits` bits of `val`.
///
/// Bit 0 of `val` becomes bit `bits - 1` of the result, and so on. Bits of
/// `val` above the low `bits` are ignored, and `bits == 0` yields 0.
fn reverse_bits(val: u32, bits: u8) -> u32 {
    // Peel bits off the bottom of `rest` and push them onto the bottom of
    // `acc`, so the first bit taken ends up highest.
    let (reversed, _) = (0..bits).fold((0u32, val), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
458
// Unit tests for header validation, the bit reader, stored-block decoding
// and the small helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_zlib_header_valid() {
        // 0x7801 is a multiple of 31 and has the dictionary flag clear, so
        // the header is two bytes long.
        let result = check_zlib_header(&[0x78, 0x01]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 2);
    }

    #[test]
    fn test_check_zlib_header_9c() {
        // 0x789C is also a multiple of 31.
        let result = check_zlib_header(&[0x78, 0x9C]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_zlib_header_too_short() {
        // A single byte cannot hold the two-byte header.
        let result = check_zlib_header(&[0x78]);
        assert_eq!(result, Err(InflateError::IncompleteInput));
    }

    #[test]
    fn test_check_zlib_header_invalid_method() {
        // Low nibble 7 is not the accepted compression method (8).
        let result = check_zlib_header(&[0x77, 0x01]);
        assert_eq!(result, Err(InflateError::InvalidHeader));
    }

    #[test]
    fn test_reverse_bits() {
        assert_eq!(reverse_bits(0b110, 3), 0b011);
        assert_eq!(reverse_bits(0b1010, 4), 0b0101);
        assert_eq!(reverse_bits(0b1, 1), 0b1);
    }

    #[test]
    fn test_inflate_raw_stored_block() {
        // 0x01: final-block bit set, block type 0 (stored).
        // 0x0005 / 0xFFFA: LEN and its complement, then the 5 payload bytes.
        let data = [
            0x01, 0x05, 0x00, 0xFA, 0xFF, b'h', b'e', b'l', b'l', b'o',
        ];
        let result = inflate_raw(&data, 1024);
        assert!(result.is_ok());
        assert_eq!(&result.unwrap(), b"hello");
    }

    #[test]
    fn test_inflate_error_types() {
        assert_eq!(InflateError::InvalidHeader, InflateError::InvalidHeader);
        assert_ne!(InflateError::InvalidHeader, InflateError::InvalidBlock);
    }

    #[test]
    fn test_build_huffman_table_empty() {
        // No code lengths at all produce a table with `max_bits == 0`.
        let lengths: [u8; 0] = [];
        let table = build_huffman_table(&lengths).unwrap();
        assert_eq!(table.max_bits, 0);
    }

    #[test]
    fn test_bit_reader_basic() {
        let data = [0b10110100u8];
        let mut reader = BitReader::new(&data);

        // Bits come out least significant first: 0, 0, 1, 0, 1, ...
        assert_eq!(reader.read_bits(1).unwrap(), 0);
        assert_eq!(reader.read_bits(1).unwrap(), 0);
        assert_eq!(reader.read_bits(1).unwrap(), 1);
        assert_eq!(reader.read_bits(1).unwrap(), 0);
        assert_eq!(reader.read_bits(1).unwrap(), 1);
    }

    #[test]
    fn test_bit_reader_multi_bit() {
        let data = [0xFF];
        let mut reader = BitReader::new(&data);
        assert_eq!(reader.read_bits(4).unwrap(), 0xF);
        assert_eq!(reader.read_bits(4).unwrap(), 0xF);
    }

    #[test]
    fn test_bit_reader_overflow() {
        let data = [0x00];
        let mut reader = BitReader::new(&data);
        // Consume the only byte, then ask for one more bit.
        let _ = reader.read_bits(8);
        let result = reader.read_bits(1);
        assert!(result.is_err());
    }

    #[test]
    fn test_inflate_buffer_overflow() {
        // Same stored block as above, but the output cap (3) is smaller
        // than the 5-byte payload.
        let data = [0x01, 0x05, 0x00, 0xFA, 0xFF, b'h', b'e', b'l', b'l', b'o'];
        let result = inflate_raw(&data, 3);
        assert_eq!(result, Err(InflateError::BufferOverflow));
    }
}