⚠️ VeridianOS Kernel Documentation — This is low-level `no_std` kernel code. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/media/image_codecs/
png.rs

1//! PNG decoder with full DEFLATE/zlib decompression.
2//!
3//! Supports all critical chunks (IHDR, PLTE, IDAT, IEND), ancillary chunks
4//! (tRNS, gAMA), all 5 filter types, Adam7 interlacing, color types
5//! 0/2/3/4/6, and bit depths 1/2/4/8/16.
6
7#![allow(dead_code)]
8
9use alloc::vec::Vec;
10
11use super::{read_be_u16, read_be_u32, DecodedImage, ImageCodecError};
12
13// ============================================================================
14// PNG DECODER
15// ============================================================================
16
/// The 8-byte magic prefix of every PNG file: `\x89`, `"PNG"`, CR LF, SUB, LF.
pub(crate) const PNG_SIGNATURE: [u8; 8] = *b"\x89PNG\r\n\x1a\n";
19
20/// PNG color types.
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22pub enum PngColorType {
23    Grayscale = 0,
24    Rgb = 2,
25    Indexed = 3,
26    GrayscaleAlpha = 4,
27    Rgba = 6,
28}
29
30impl PngColorType {
31    fn from_u8(v: u8) -> Result<Self, ImageCodecError> {
32        match v {
33            0 => Ok(Self::Grayscale),
34            2 => Ok(Self::Rgb),
35            3 => Ok(Self::Indexed),
36            4 => Ok(Self::GrayscaleAlpha),
37            6 => Ok(Self::Rgba),
38            _ => Err(ImageCodecError::Unsupported),
39        }
40    }
41
42    /// Number of channels (samples per pixel).
43    fn channels(self) -> usize {
44        match self {
45            Self::Grayscale => 1,
46            Self::Rgb => 3,
47            Self::Indexed => 1,
48            Self::GrayscaleAlpha => 2,
49            Self::Rgba => 4,
50        }
51    }
52}
53
54/// Parsed IHDR data.
55#[derive(Debug, Clone, Copy)]
56struct PngIhdr {
57    width: u32,
58    height: u32,
59    bit_depth: u8,
60    color_type: PngColorType,
61    interlace: u8,
62}
63
64/// Decode a PNG image from raw file data.
65pub fn decode_png(data: &[u8]) -> Result<DecodedImage, ImageCodecError> {
66    // Verify signature
67    if data.len() < 8 || data[..8] != PNG_SIGNATURE {
68        return Err(ImageCodecError::InvalidSignature);
69    }
70
71    // Parse chunks
72    let mut pos: usize = 8;
73    let mut ihdr: Option<PngIhdr> = None;
74    let mut palette: Vec<(u8, u8, u8)> = Vec::new();
75    let mut trns: Vec<u8> = Vec::new();
76    let mut idat_data: Vec<u8> = Vec::new();
77    let mut _gamma: u32 = 0; // stored as gamma * 100000
78
79    while pos + 8 <= data.len() {
80        let chunk_len = read_be_u32(data, pos) as usize;
81        let chunk_type = &data[pos + 4..pos + 8];
82        let chunk_data_start = pos + 8;
83        let chunk_end = chunk_data_start + chunk_len;
84
85        if chunk_end + 4 > data.len() {
86            // Not enough data for chunk + CRC
87            break;
88        }
89
90        match chunk_type {
91            b"IHDR" => {
92                if chunk_len < 13 {
93                    return Err(ImageCodecError::CorruptData);
94                }
95                let cd = &data[chunk_data_start..chunk_end];
96                ihdr = Some(PngIhdr {
97                    width: read_be_u32(cd, 0),
98                    height: read_be_u32(cd, 4),
99                    bit_depth: cd[8],
100                    color_type: PngColorType::from_u8(cd[9])?,
101                    interlace: cd[12],
102                });
103            }
104            b"PLTE" => {
105                if !chunk_len.is_multiple_of(3) {
106                    return Err(ImageCodecError::CorruptData);
107                }
108                let cd = &data[chunk_data_start..chunk_end];
109                palette.clear();
110                let mut i = 0;
111                while i + 2 < cd.len() {
112                    palette.push((cd[i], cd[i + 1], cd[i + 2]));
113                    i += 3;
114                }
115            }
116            b"tRNS" => {
117                trns = data[chunk_data_start..chunk_end].to_vec();
118            }
119            b"gAMA" => {
120                if chunk_len >= 4 {
121                    _gamma = read_be_u32(data, chunk_data_start);
122                }
123            }
124            b"IDAT" => {
125                idat_data.extend_from_slice(&data[chunk_data_start..chunk_end]);
126            }
127            b"IEND" => {
128                break;
129            }
130            _ => {
131                // Skip unknown/ancillary chunks
132            }
133        }
134
135        pos = chunk_end + 4; // skip CRC
136    }
137
138    let ihdr = ihdr.ok_or(ImageCodecError::CorruptData)?;
139    if ihdr.width == 0 || ihdr.height == 0 {
140        return Err(ImageCodecError::InvalidDimensions);
141    }
142
143    // Validate bit depth for color type
144    match ihdr.color_type {
145        PngColorType::Grayscale => {
146            if !matches!(ihdr.bit_depth, 1 | 2 | 4 | 8 | 16) {
147                return Err(ImageCodecError::Unsupported);
148            }
149        }
150        PngColorType::Rgb | PngColorType::GrayscaleAlpha | PngColorType::Rgba => {
151            if !matches!(ihdr.bit_depth, 8 | 16) {
152                return Err(ImageCodecError::Unsupported);
153            }
154        }
155        PngColorType::Indexed => {
156            if !matches!(ihdr.bit_depth, 1 | 2 | 4 | 8) {
157                return Err(ImageCodecError::Unsupported);
158            }
159        }
160    }
161
162    // Decompress zlib-wrapped IDAT data
163    let raw_data = zlib_decompress(&idat_data)?;
164
165    // Unfilter and produce RGBA output
166    if ihdr.interlace == 1 {
167        decode_png_interlaced(&ihdr, &raw_data, &palette, &trns)
168    } else {
169        decode_png_non_interlaced(&ihdr, &raw_data, &palette, &trns)
170    }
171}
172
173/// Decode non-interlaced PNG scanlines.
174fn decode_png_non_interlaced(
175    ihdr: &PngIhdr,
176    raw: &[u8],
177    palette: &[(u8, u8, u8)],
178    trns: &[u8],
179) -> Result<DecodedImage, ImageCodecError> {
180    let w = ihdr.width as usize;
181    let h = ihdr.height as usize;
182    let channels = ihdr.color_type.channels();
183    let bits_per_pixel = channels * (ihdr.bit_depth as usize);
184    let bytes_per_row = (w * bits_per_pixel).div_ceil(8);
185    let bpp_bytes = bits_per_pixel.div_ceil(8); // filter byte stride
186
187    let mut img = DecodedImage::new(ihdr.width, ihdr.height);
188    let mut prev_row: Vec<u8> = alloc::vec![0u8; bytes_per_row];
189    let mut pos: usize = 0;
190
191    for y in 0..h {
192        if pos >= raw.len() {
193            return Err(ImageCodecError::TruncatedData);
194        }
195        let filter = raw[pos];
196        pos += 1;
197
198        if pos + bytes_per_row > raw.len() {
199            return Err(ImageCodecError::TruncatedData);
200        }
201
202        let mut current_row: Vec<u8> = raw[pos..pos + bytes_per_row].to_vec();
203        pos += bytes_per_row;
204
205        // Apply PNG filter reconstruction
206        png_unfilter(filter, &mut current_row, &prev_row, bpp_bytes)?;
207
208        // Convert scanline to RGBA pixels
209        png_scanline_to_rgba(
210            &current_row,
211            ihdr,
212            palette,
213            trns,
214            &mut img,
215            y as u32,
216            0,
217            w as u32,
218        );
219
220        prev_row = current_row;
221    }
222
223    Ok(img)
224}
225
226/// Decode Adam7 interlaced PNG.
227fn decode_png_interlaced(
228    ihdr: &PngIhdr,
229    raw: &[u8],
230    palette: &[(u8, u8, u8)],
231    trns: &[u8],
232) -> Result<DecodedImage, ImageCodecError> {
233    let w = ihdr.width as usize;
234    let h = ihdr.height as usize;
235    let channels = ihdr.color_type.channels();
236    let bits_per_pixel = channels * (ihdr.bit_depth as usize);
237
238    // Adam7 pass parameters: (x_start, y_start, x_step, y_step)
239    const ADAM7: [(usize, usize, usize, usize); 7] = [
240        (0, 0, 8, 8),
241        (4, 0, 8, 8),
242        (0, 4, 4, 8),
243        (2, 0, 4, 4),
244        (0, 2, 2, 4),
245        (1, 0, 2, 2),
246        (0, 1, 1, 2),
247    ];
248
249    let mut img = DecodedImage::new(ihdr.width, ihdr.height);
250    let mut pos: usize = 0;
251
252    for &(x_start, y_start, x_step, y_step) in &ADAM7 {
253        let pass_w = if x_start >= w {
254            0
255        } else {
256            (w - x_start).div_ceil(x_step)
257        };
258        let pass_h = if y_start >= h {
259            0
260        } else {
261            (h - y_start).div_ceil(y_step)
262        };
263
264        if pass_w == 0 || pass_h == 0 {
265            continue;
266        }
267
268        let bytes_per_row = (pass_w * bits_per_pixel).div_ceil(8);
269        let bpp_bytes = bits_per_pixel.div_ceil(8);
270        let mut prev_row: Vec<u8> = alloc::vec![0u8; bytes_per_row];
271
272        for pass_y in 0..pass_h {
273            if pos >= raw.len() {
274                return Err(ImageCodecError::TruncatedData);
275            }
276            let filter = raw[pos];
277            pos += 1;
278
279            if pos + bytes_per_row > raw.len() {
280                return Err(ImageCodecError::TruncatedData);
281            }
282
283            let mut current_row: Vec<u8> = raw[pos..pos + bytes_per_row].to_vec();
284            pos += bytes_per_row;
285
286            png_unfilter(filter, &mut current_row, &prev_row, bpp_bytes)?;
287
288            // Place pixels at correct interlaced positions
289            let out_y = y_start + pass_y * y_step;
290            for pass_x in 0..pass_w {
291                let out_x = x_start + pass_x * x_step;
292                let pixel = png_extract_pixel(&current_row, pass_x, ihdr, palette, trns);
293                img.set_pixel(
294                    out_x as u32,
295                    out_y as u32,
296                    pixel.0,
297                    pixel.1,
298                    pixel.2,
299                    pixel.3,
300                );
301            }
302
303            prev_row = current_row;
304        }
305    }
306
307    Ok(img)
308}
309
310/// PNG filter reconstruction (RFC 2083 Section 9).
311fn png_unfilter(
312    filter: u8,
313    current: &mut [u8],
314    prev: &[u8],
315    bpp: usize,
316) -> Result<(), ImageCodecError> {
317    let len = current.len();
318    match filter {
319        0 => {} // None
320        1 => {
321            // Sub
322            for i in bpp..len {
323                current[i] = current[i].wrapping_add(current[i - bpp]);
324            }
325        }
326        2 => {
327            // Up
328            for i in 0..len {
329                current[i] = current[i].wrapping_add(prev[i]);
330            }
331        }
332        3 => {
333            // Average
334            for i in 0..len {
335                let a = if i >= bpp { current[i - bpp] as u16 } else { 0 };
336                let b = prev[i] as u16;
337                current[i] = current[i].wrapping_add(((a + b) / 2) as u8);
338            }
339        }
340        4 => {
341            // Paeth
342            for i in 0..len {
343                let a = if i >= bpp { current[i - bpp] as i32 } else { 0 };
344                let b = prev[i] as i32;
345                let c = if i >= bpp { prev[i - bpp] as i32 } else { 0 };
346                current[i] = current[i].wrapping_add(paeth_predictor(a, b, c) as u8);
347            }
348        }
349        _ => return Err(ImageCodecError::CorruptData),
350    }
351    Ok(())
352}
353
/// Paeth predictor using integer arithmetic only.
///
/// Returns whichever of `a` (left), `b` (up) or `c` (upper-left) lies closest
/// to the linear estimate `a + b - c`. Ties prefer `a`, then `b`.
#[inline]
fn paeth_predictor(a: i32, b: i32, c: i32) -> i32 {
    let estimate = a + b - c;
    let dist_a = (estimate - a).abs();
    let dist_b = (estimate - b).abs();
    let dist_c = (estimate - c).abs();
    match (dist_a <= dist_b && dist_a <= dist_c, dist_b <= dist_c) {
        (true, _) => a,
        (false, true) => b,
        (false, false) => c,
    }
}
369
370/// Convert a full PNG scanline to RGBA pixels in the output image.
371fn png_scanline_to_rgba(
372    row: &[u8],
373    ihdr: &PngIhdr,
374    palette: &[(u8, u8, u8)],
375    trns: &[u8],
376    img: &mut DecodedImage,
377    y: u32,
378    x_start: u32,
379    count: u32,
380) {
381    for x in 0..count {
382        let pixel = png_extract_pixel(row, x as usize, ihdr, palette, trns);
383        img.set_pixel(x_start + x, y, pixel.0, pixel.1, pixel.2, pixel.3);
384    }
385}
386
387/// Extract a single pixel from a PNG scanline, returning (R, G, B, A).
388fn png_extract_pixel(
389    row: &[u8],
390    x: usize,
391    ihdr: &PngIhdr,
392    palette: &[(u8, u8, u8)],
393    trns: &[u8],
394) -> (u8, u8, u8, u8) {
395    let bd = ihdr.bit_depth as usize;
396
397    match ihdr.color_type {
398        PngColorType::Grayscale => {
399            let v = extract_sample(row, x, bd);
400            let v8 = scale_to_8bit(v, bd);
401            let a = if trns.len() >= 2 {
402                let trns_val = read_be_u16(trns, 0) as usize;
403                if v == trns_val {
404                    0
405                } else {
406                    255
407                }
408            } else {
409                255
410            };
411            (v8, v8, v8, a)
412        }
413        PngColorType::Rgb => {
414            let bytes_per_sample = if bd == 16 { 2 } else { 1 };
415            let off = x * 3 * bytes_per_sample;
416            let (r, g, b) = if bd == 16 {
417                if off + 5 < row.len() {
418                    (row[off], row[off + 2], row[off + 4])
419                } else {
420                    (0, 0, 0)
421                }
422            } else if off + 2 < row.len() {
423                (row[off], row[off + 1], row[off + 2])
424            } else {
425                (0, 0, 0)
426            };
427            let a = if trns.len() >= 6 {
428                let tr = read_be_u16(trns, 0);
429                let tg = read_be_u16(trns, 2);
430                let tb = read_be_u16(trns, 4);
431                let (cr, cg, cb) = if bd == 16 {
432                    (
433                        read_be_u16(row, off),
434                        read_be_u16(row, off + 2),
435                        read_be_u16(row, off + 4),
436                    )
437                } else {
438                    (r as u16, g as u16, b as u16)
439                };
440                if cr == tr && cg == tg && cb == tb {
441                    0
442                } else {
443                    255
444                }
445            } else {
446                255
447            };
448            (r, g, b, a)
449        }
450        PngColorType::Indexed => {
451            let idx = extract_sample(row, x, bd);
452            if idx < palette.len() {
453                let (r, g, b) = palette[idx];
454                let a = if idx < trns.len() { trns[idx] } else { 255 };
455                (r, g, b, a)
456            } else {
457                (0, 0, 0, 255)
458            }
459        }
460        PngColorType::GrayscaleAlpha => {
461            let bytes_per_sample = if bd == 16 { 2 } else { 1 };
462            let off = x * 2 * bytes_per_sample;
463            let (v, a) = if bd == 16 {
464                if off + 3 < row.len() {
465                    (row[off], row[off + 2])
466                } else {
467                    (0, 0)
468                }
469            } else if off + 1 < row.len() {
470                (row[off], row[off + 1])
471            } else {
472                (0, 0)
473            };
474            (v, v, v, a)
475        }
476        PngColorType::Rgba => {
477            let bytes_per_sample = if bd == 16 { 2 } else { 1 };
478            let off = x * 4 * bytes_per_sample;
479            if bd == 16 {
480                if off + 7 < row.len() {
481                    (row[off], row[off + 2], row[off + 4], row[off + 6])
482                } else {
483                    (0, 0, 0, 0)
484                }
485            } else if off + 3 < row.len() {
486                (row[off], row[off + 1], row[off + 2], row[off + 3])
487            } else {
488                (0, 0, 0, 0)
489            }
490        }
491    }
492}
493
494/// Extract a sub-byte sample value from packed scanline data.
495fn extract_sample(row: &[u8], index: usize, bit_depth: usize) -> usize {
496    match bit_depth {
497        1 => {
498            let byte_idx = index / 8;
499            let bit_idx = 7 - (index % 8);
500            if byte_idx < row.len() {
501                ((row[byte_idx] >> bit_idx) & 1) as usize
502            } else {
503                0
504            }
505        }
506        2 => {
507            let byte_idx = index / 4;
508            let shift = 6 - (index % 4) * 2;
509            if byte_idx < row.len() {
510                ((row[byte_idx] >> shift) & 3) as usize
511            } else {
512                0
513            }
514        }
515        4 => {
516            let byte_idx = index / 2;
517            let shift = if index.is_multiple_of(2) { 4 } else { 0 };
518            if byte_idx < row.len() {
519                ((row[byte_idx] >> shift) & 0xF) as usize
520            } else {
521                0
522            }
523        }
524        8 => {
525            if index < row.len() {
526                row[index] as usize
527            } else {
528                0
529            }
530        }
531        16 => {
532            let off = index * 2;
533            if off + 1 < row.len() {
534                read_be_u16(row, off) as usize
535            } else {
536                0
537            }
538        }
539        _ => 0,
540    }
541}
542
/// Map a sample from its native bit depth onto the 0-255 range.
///
/// 1-bit is black/white, 2-bit multiplies by 85 and 4-bit by 255/15, which
/// stretches the full range exactly. 16-bit keeps the high byte. 8-bit and any
/// other depth are truncated to `u8`.
fn scale_to_8bit(val: usize, bit_depth: usize) -> u8 {
    match (bit_depth, val) {
        (1, 0) => 0,
        (1, _) => 255,
        (2, v) => (v * 85) as u8,
        (4, v) => ((v * 255) / 15) as u8,
        (16, v) => (v >> 8) as u8,
        (_, v) => v as u8,
    }
}
560
561// ============================================================================
562// DEFLATE / ZLIB DECOMPRESSION (RFC 1950 / 1951)
563// ============================================================================
564
565/// Decompress zlib-wrapped data (CMF + FLG + compressed blocks + Adler-32).
566fn zlib_decompress(data: &[u8]) -> Result<Vec<u8>, ImageCodecError> {
567    if data.len() < 6 {
568        return Err(ImageCodecError::TruncatedData);
569    }
570
571    let cmf = data[0];
572    let _flg = data[1];
573
574    // CMF: bits 0-3 = CM (must be 8 for deflate), bits 4-7 = CINFO
575    if cmf & 0x0F != 8 {
576        return Err(ImageCodecError::Unsupported);
577    }
578
579    // Verify CMF/FLG check
580    let check = (cmf as u16) * 256 + (_flg as u16);
581    if !check.is_multiple_of(31) {
582        return Err(ImageCodecError::CorruptData);
583    }
584
585    // Check FDICT flag (bit 5 of FLG) -- we don't support preset dictionaries
586    if _flg & 0x20 != 0 {
587        return Err(ImageCodecError::Unsupported);
588    }
589
590    // Decompress DEFLATE stream starting at offset 2
591    let compressed = &data[2..];
592    let output = deflate_decompress(compressed)?;
593
594    // Verify Adler-32 checksum (last 4 bytes of zlib stream)
595    if data.len() >= 6 {
596        let adler_offset = data.len() - 4;
597        let stored_adler = read_be_u32(data, adler_offset);
598        let computed_adler = adler32(&output);
599        if stored_adler != computed_adler {
600            // Some PNG encoders produce valid images with wrong checksums;
601            // we log but don't fail for robustness.
602            // return Err(ImageCodecError::ChecksumMismatch);
603        }
604    }
605
606    Ok(output)
607}
608
/// Compute the Adler-32 checksum of `data` (RFC 1950).
///
/// Keeps two sums modulo 65521:
/// - `s1`: 1 plus the sum of the bytes;
/// - `s2`: the sum of every intermediate `s1`.
///
/// The result is `s2 << 16 | s1`.
fn adler32(data: &[u8]) -> u32 {
    const MOD_ADLER: u32 = 65521;
    let (s1, s2) = data.iter().fold((1u32, 0u32), |(s1, s2), &byte| {
        let s1 = (s1 + u32::from(byte)) % MOD_ADLER;
        (s1, (s2 + s1) % MOD_ADLER)
    });
    (s2 << 16) | s1
}
621
622/// Bit reader for DEFLATE stream.
623struct BitReader<'a> {
624    data: &'a [u8],
625    byte_pos: usize,
626    bit_pos: u8, // 0-7, bits consumed in current byte
627}
628
629impl<'a> BitReader<'a> {
630    fn new(data: &'a [u8]) -> Self {
631        Self {
632            data,
633            byte_pos: 0,
634            bit_pos: 0,
635        }
636    }
637
638    /// Read `n` bits (up to 25), LSB first.
639    fn read_bits(&mut self, n: u8) -> Result<u32, ImageCodecError> {
640        let mut result: u32 = 0;
641        let mut bits_read: u8 = 0;
642
643        while bits_read < n {
644            if self.byte_pos >= self.data.len() {
645                return Err(ImageCodecError::TruncatedData);
646            }
647
648            let available = 8 - self.bit_pos;
649            let needed = n - bits_read;
650            let take = if available < needed {
651                available
652            } else {
653                needed
654            };
655            let mask = (1u32 << take) - 1;
656            let bits = ((self.data[self.byte_pos] >> self.bit_pos) as u32) & mask;
657            result |= bits << bits_read;
658            bits_read += take;
659            self.bit_pos += take;
660
661            if self.bit_pos >= 8 {
662                self.bit_pos = 0;
663                self.byte_pos += 1;
664            }
665        }
666
667        Ok(result)
668    }
669
670    /// Read a single bit.
671    fn read_bit(&mut self) -> Result<u32, ImageCodecError> {
672        self.read_bits(1)
673    }
674
675    /// Align to byte boundary.
676    fn align(&mut self) {
677        if self.bit_pos > 0 {
678            self.bit_pos = 0;
679            self.byte_pos += 1;
680        }
681    }
682
683    /// Read a byte (must be byte-aligned).
684    fn read_byte(&mut self) -> Result<u8, ImageCodecError> {
685        if self.byte_pos >= self.data.len() {
686            return Err(ImageCodecError::TruncatedData);
687        }
688        let b = self.data[self.byte_pos];
689        self.byte_pos += 1;
690        Ok(b)
691    }
692}
693
694/// DEFLATE decompression (RFC 1951).
695fn deflate_decompress(data: &[u8]) -> Result<Vec<u8>, ImageCodecError> {
696    let mut reader = BitReader::new(data);
697    let mut output: Vec<u8> = Vec::new();
698
699    loop {
700        let bfinal = reader.read_bit()?;
701        let btype = reader.read_bits(2)?;
702
703        match btype {
704            0 => {
705                // Stored (uncompressed) block
706                reader.align();
707                let len = reader.read_byte()? as u16 | ((reader.read_byte()? as u16) << 8);
708                let _nlen = reader.read_byte()? as u16 | ((reader.read_byte()? as u16) << 8);
709                for _ in 0..len {
710                    output.push(reader.read_byte()?);
711                }
712            }
713            1 => {
714                // Fixed Huffman codes
715                let (lit_tree, dist_tree) = build_fixed_huffman_trees();
716                inflate_block(&mut reader, &lit_tree, &dist_tree, &mut output)?;
717            }
718            2 => {
719                // Dynamic Huffman codes
720                let (lit_tree, dist_tree) = decode_dynamic_huffman(&mut reader)?;
721                inflate_block(&mut reader, &lit_tree, &dist_tree, &mut output)?;
722            }
723            _ => return Err(ImageCodecError::CorruptData),
724        }
725
726        if bfinal != 0 {
727            break;
728        }
729    }
730
731    Ok(output)
732}
733
734/// A Huffman tree for DEFLATE decoding, stored as a lookup table.
735/// Maximum code length in DEFLATE is 15 bits.
736struct HuffmanTree {
737    /// For each code length, the number of codes and starting values.
738    /// Stored as (min_code, symbols) per bit length.
739    /// We use a simple linear decode approach.
740    counts: [u16; 16],
741    symbols: Vec<u16>,
742    min_codes: [u32; 16],
743    max_codes: [i32; 16],
744    offsets: [u16; 16],
745}
746
747impl HuffmanTree {
748    /// Build a Huffman tree from a list of code lengths.
749    fn from_lengths(lengths: &[u8]) -> Result<Self, ImageCodecError> {
750        let mut counts = [0u16; 16];
751        let mut max_len: usize = 0;
752
753        // Count occurrences of each code length
754        for &len in lengths {
755            let l = len as usize;
756            if l > 0 {
757                if l > 15 {
758                    return Err(ImageCodecError::InvalidHuffmanTable);
759                }
760                counts[l] += 1;
761                if l > max_len {
762                    max_len = l;
763                }
764            }
765        }
766
767        // Compute starting codes for each length
768        let mut code: u32 = 0;
769        let mut next_code = [0u32; 16];
770        let mut min_codes = [0u32; 16];
771        let mut max_codes = [-1i32; 16];
772        let mut offsets = [0u16; 16];
773
774        let mut offset: u16 = 0;
775        for bits in 1..=max_len {
776            code = (code + counts[bits - 1] as u32) << 1;
777            next_code[bits] = code;
778            min_codes[bits] = code;
779            offsets[bits] = offset;
780            if counts[bits] > 0 {
781                max_codes[bits] = (code + counts[bits] as u32 - 1) as i32;
782            }
783            offset += counts[bits];
784        }
785
786        // Assign symbols
787        let total_symbols = offset as usize;
788        let mut symbols = alloc::vec![0u16; total_symbols];
789        let mut symbol_idx = [0u16; 16];
790        symbol_idx[1..16].copy_from_slice(&offsets[1..16]);
791
792        for (sym, &len) in lengths.iter().enumerate() {
793            let l = len as usize;
794            if l > 0 && l < 16 {
795                let idx = symbol_idx[l] as usize;
796                if idx < symbols.len() {
797                    symbols[idx] = sym as u16;
798                    symbol_idx[l] += 1;
799                }
800            }
801        }
802
803        Ok(Self {
804            counts,
805            symbols,
806            min_codes,
807            max_codes,
808            offsets,
809        })
810    }
811
812    /// Decode one symbol from the bit stream.
813    fn decode(&self, reader: &mut BitReader) -> Result<u16, ImageCodecError> {
814        let mut code: u32 = 0;
815
816        for bits in 1..16u8 {
817            code = (code << 1) | reader.read_bit()?;
818            let b = bits as usize;
819            if self.max_codes[b] >= 0 && code <= self.max_codes[b] as u32 {
820                let idx = self.offsets[b] as usize + (code - self.min_codes[b]) as usize;
821                if idx < self.symbols.len() {
822                    return Ok(self.symbols[idx]);
823                }
824            }
825        }
826
827        Err(ImageCodecError::InvalidHuffmanTable)
828    }
829}
830
831/// Build the fixed Huffman trees for DEFLATE block type 1.
832fn build_fixed_huffman_trees() -> (HuffmanTree, HuffmanTree) {
833    // Literal/length: 0-143 => 8 bits, 144-255 => 9 bits, 256-279 => 7 bits,
834    // 280-287 => 8 bits
835    let mut lit_lengths = [0u8; 288];
836    lit_lengths[0..=143].fill(8);
837    lit_lengths[144..=255].fill(9);
838    lit_lengths[256..=279].fill(7);
839    lit_lengths[280..=287].fill(8);
840
841    // Distance: all 32 codes are 5 bits
842    let dist_lengths = [5u8; 32];
843
844    (
845        HuffmanTree::from_lengths(&lit_lengths).unwrap_or_else(|_| HuffmanTree {
846            counts: [0; 16],
847            symbols: Vec::new(),
848            min_codes: [0; 16],
849            max_codes: [-1; 16],
850            offsets: [0; 16],
851        }),
852        HuffmanTree::from_lengths(&dist_lengths).unwrap_or_else(|_| HuffmanTree {
853            counts: [0; 16],
854            symbols: Vec::new(),
855            min_codes: [0; 16],
856            max_codes: [-1; 16],
857            offsets: [0; 16],
858        }),
859    )
860}
861
862/// Decode dynamic Huffman trees from DEFLATE block type 2 header.
863fn decode_dynamic_huffman(
864    reader: &mut BitReader,
865) -> Result<(HuffmanTree, HuffmanTree), ImageCodecError> {
866    let hlit = reader.read_bits(5)? as usize + 257;
867    let hdist = reader.read_bits(5)? as usize + 1;
868    let hclen = reader.read_bits(4)? as usize + 4;
869
870    // Code length alphabet order
871    const CL_ORDER: [usize; 19] = [
872        16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
873    ];
874
875    let mut cl_lengths = [0u8; 19];
876    for i in 0..hclen {
877        cl_lengths[CL_ORDER[i]] = reader.read_bits(3)? as u8;
878    }
879
880    let cl_tree = HuffmanTree::from_lengths(&cl_lengths)?;
881
882    // Decode literal/length + distance code lengths
883    let total = hlit + hdist;
884    let mut lengths: Vec<u8> = Vec::with_capacity(total);
885
886    while lengths.len() < total {
887        let sym = cl_tree.decode(reader)?;
888        match sym {
889            0..=15 => {
890                lengths.push(sym as u8);
891            }
892            16 => {
893                // Repeat previous length 3-6 times
894                let extra = reader.read_bits(2)? as usize + 3;
895                let prev = if let Some(&last) = lengths.last() {
896                    last
897                } else {
898                    0
899                };
900                for _ in 0..extra {
901                    if lengths.len() < total {
902                        lengths.push(prev);
903                    }
904                }
905            }
906            17 => {
907                // Repeat 0 for 3-10 times
908                let extra = reader.read_bits(3)? as usize + 3;
909                for _ in 0..extra {
910                    if lengths.len() < total {
911                        lengths.push(0);
912                    }
913                }
914            }
915            18 => {
916                // Repeat 0 for 11-138 times
917                let extra = reader.read_bits(7)? as usize + 11;
918                for _ in 0..extra {
919                    if lengths.len() < total {
920                        lengths.push(0);
921                    }
922                }
923            }
924            _ => return Err(ImageCodecError::InvalidHuffmanTable),
925        }
926    }
927
928    let lit_tree = HuffmanTree::from_lengths(&lengths[..hlit])?;
929    let dist_tree = HuffmanTree::from_lengths(&lengths[hlit..hlit + hdist])?;
930
931    Ok((lit_tree, dist_tree))
932}
933
/// Base match length for length codes 257-285, indexed by `code - 257`.
///
/// Code 257 starts at length 3. Each code covers `1 << LENGTH_EXTRA[i]`
/// lengths, except code 285, which is pinned to 258.
const LENGTH_BASE: [u16; 29] = [
    3, 4, 5, 6, 7, 8, 9, 10, // 257-264: 0 extra bits
    11, 13, 15, 17, // 265-268: 1 extra bit
    19, 23, 27, 31, // 269-272: 2 extra bits
    35, 43, 51, 59, // 273-276: 3 extra bits
    67, 83, 99, 115, // 277-280: 4 extra bits
    131, 163, 195, 227, // 281-284: 5 extra bits
    258, // 285: 0 extra bits
];
/// Extra bits read after each length code, indexed like [`LENGTH_BASE`].
const LENGTH_EXTRA: [u8; 29] = [
    0, 0, 0, 0, 0, 0, 0, 0, // 257-264
    1, 1, 1, 1, // 265-268
    2, 2, 2, 2, // 269-272
    3, 3, 3, 3, // 273-276
    4, 4, 4, 4, // 277-280
    5, 5, 5, 5, // 281-284
    0, // 285
];

/// Base back-reference distance for distance codes 0-29.
const DIST_BASE: [u16; 30] = [
    1, 2, 3, 4, // codes 0-3
    5, 7, // 4-5
    9, 13, // 6-7
    17, 25, // 8-9
    33, 49, // 10-11
    65, 97, // 12-13
    129, 193, // 14-15
    257, 385, // 16-17
    513, 769, // 18-19
    1025, 1537, // 20-21
    2049, 3073, // 22-23
    4097, 6145, // 24-25
    8193, 12289, // 26-27
    16385, 24577, // 28-29
];
/// Extra bits read after each distance code, indexed like [`DIST_BASE`].
const DIST_EXTRA: [u8; 30] = [
    0, 0, 0, 0, // codes 0-3
    1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, // 4-15
    7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, // 16-29
];
952
953/// Inflate one DEFLATE block using the given Huffman trees.
954fn inflate_block(
955    reader: &mut BitReader,
956    lit_tree: &HuffmanTree,
957    dist_tree: &HuffmanTree,
958    output: &mut Vec<u8>,
959) -> Result<(), ImageCodecError> {
960    loop {
961        let sym = lit_tree.decode(reader)?;
962
963        if sym < 256 {
964            // Literal byte
965            output.push(sym as u8);
966        } else if sym == 256 {
967            // End of block
968            break;
969        } else {
970            // Length/distance pair
971            let len_idx = (sym - 257) as usize;
972            if len_idx >= LENGTH_BASE.len() {
973                return Err(ImageCodecError::CorruptData);
974            }
975            let length =
976                LENGTH_BASE[len_idx] as usize + reader.read_bits(LENGTH_EXTRA[len_idx])? as usize;
977
978            let dist_sym = dist_tree.decode(reader)? as usize;
979            if dist_sym >= DIST_BASE.len() {
980                return Err(ImageCodecError::CorruptData);
981            }
982            let distance =
983                DIST_BASE[dist_sym] as usize + reader.read_bits(DIST_EXTRA[dist_sym])? as usize;
984
985            if distance > output.len() {
986                return Err(ImageCodecError::CorruptData);
987            }
988
989            // Copy from back-reference (byte-by-byte for overlapping copies)
990            let start = output.len() - distance;
991            for i in 0..length {
992                let b = output[start + (i % distance)];
993                output.push(b);
994            }
995        }
996    }
997
998    Ok(())
999}
1000
1001// ============================================================================
1002// TESTS
1003// ============================================================================
1004
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    #[test]
    fn test_png_signature_check() {
        let bad = [0u8; 8];
        assert_eq!(decode_png(&bad), Err(ImageCodecError::InvalidSignature));
    }

    #[test]
    fn test_png_too_short() {
        let data = [137, 80, 78, 71];
        assert_eq!(decode_png(&data), Err(ImageCodecError::InvalidSignature));
    }

    #[test]
    fn test_png_color_type_from_u8() {
        assert_eq!(PngColorType::from_u8(0), Ok(PngColorType::Grayscale));
        assert_eq!(PngColorType::from_u8(2), Ok(PngColorType::Rgb));
        assert_eq!(PngColorType::from_u8(3), Ok(PngColorType::Indexed));
        assert_eq!(PngColorType::from_u8(4), Ok(PngColorType::GrayscaleAlpha));
        assert_eq!(PngColorType::from_u8(6), Ok(PngColorType::Rgba));
        assert_eq!(PngColorType::from_u8(7), Err(ImageCodecError::Unsupported));
    }

    #[test]
    fn test_png_color_type_channels() {
        assert_eq!(PngColorType::Grayscale.channels(), 1);
        assert_eq!(PngColorType::Rgb.channels(), 3);
        assert_eq!(PngColorType::Indexed.channels(), 1);
        assert_eq!(PngColorType::GrayscaleAlpha.channels(), 2);
        assert_eq!(PngColorType::Rgba.channels(), 4);
    }

    #[test]
    fn test_png_unfilter_none() {
        let mut row = vec![1, 2, 3, 4];
        let prev = vec![0, 0, 0, 0];
        png_unfilter(0, &mut row, &prev, 1).unwrap();
        assert_eq!(row, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_png_unfilter_sub() {
        let mut row = vec![1, 2, 3, 4];
        let prev = vec![0, 0, 0, 0];
        png_unfilter(1, &mut row, &prev, 1).unwrap();
        // Sub: each byte += the already-reconstructed byte to its left.
        // [1, 1+2=3, 3+3=6, 6+4=10]
        assert_eq!(row, vec![1, 3, 6, 10]);
    }

    #[test]
    fn test_png_unfilter_up() {
        let mut row = vec![1, 2, 3, 4];
        let prev = vec![10, 20, 30, 40];
        png_unfilter(2, &mut row, &prev, 1).unwrap();
        assert_eq!(row, vec![11, 22, 33, 44]);
    }

    #[test]
    fn test_png_unfilter_average() {
        let mut row = vec![0, 0, 0, 0];
        let prev = vec![10, 20, 30, 40];
        png_unfilter(3, &mut row, &prev, 1).unwrap();
        // Average: byte += floor((a + b) / 2), where a = reconstructed left byte
        // and b = byte above.
        // [0 + floor((0 + 10) / 2) = 5, 0 + floor((5 + 20) / 2) = 12,
        //  0 + floor((12 + 30) / 2) = 21, 0 + floor((21 + 40) / 2) = 30]
        assert_eq!(row, vec![5, 12, 21, 30]);
    }

    #[test]
    fn test_png_unfilter_paeth() {
        let mut row = vec![10, 20, 30, 40];
        let prev = vec![0, 0, 0, 0];
        png_unfilter(4, &mut row, &prev, 1).unwrap();
        // Paeth with an all-zero previous row: b = c = 0, so the predictor picks
        // a (the reconstructed left byte, 0 for the first byte).
        // [10 + paeth(0, 0, 0) = 10, 20 + paeth(10, 0, 0) = 30,
        //  30 + paeth(30, 0, 0) = 60, 40 + paeth(60, 0, 0) = 100]
        assert_eq!(row, vec![10, 30, 60, 100]);
    }

    #[test]
    fn test_paeth_predictor_basic() {
        // When a=0, b=0, c=0 => p=0, pa=0, pb=0, pc=0 => returns a=0
        assert_eq!(paeth_predictor(0, 0, 0), 0);
        // When a=10, b=20, c=5 => p=25, pa=15, pb=5, pc=20 => returns b=20
        assert_eq!(paeth_predictor(10, 20, 5), 20);
    }

    #[test]
    fn test_extract_sample_1bit() {
        let row = vec![0b10110100];
        assert_eq!(extract_sample(&row, 0, 1), 1); // bit 7
        assert_eq!(extract_sample(&row, 1, 1), 0); // bit 6
        assert_eq!(extract_sample(&row, 2, 1), 1); // bit 5
        assert_eq!(extract_sample(&row, 3, 1), 1); // bit 4
        assert_eq!(extract_sample(&row, 4, 1), 0); // bit 3
        assert_eq!(extract_sample(&row, 5, 1), 1); // bit 2
    }

    #[test]
    fn test_extract_sample_4bit() {
        let row = vec![0xAB, 0xCD];
        assert_eq!(extract_sample(&row, 0, 4), 0xA);
        assert_eq!(extract_sample(&row, 1, 4), 0xB);
        assert_eq!(extract_sample(&row, 2, 4), 0xC);
        assert_eq!(extract_sample(&row, 3, 4), 0xD);
    }

    #[test]
    fn test_scale_to_8bit() {
        assert_eq!(scale_to_8bit(0, 1), 0);
        assert_eq!(scale_to_8bit(1, 1), 255);
        assert_eq!(scale_to_8bit(0, 2), 0);
        assert_eq!(scale_to_8bit(3, 2), 255);
        assert_eq!(scale_to_8bit(128, 8), 128);
        assert_eq!(scale_to_8bit(0xFF00, 16), 255);
    }

    #[test]
    fn test_png_unfilter_invalid_filter() {
        let mut row = vec![1, 2, 3];
        let prev = vec![0, 0, 0];
        assert_eq!(
            png_unfilter(5, &mut row, &prev, 1),
            Err(ImageCodecError::CorruptData)
        );
    }

    // -----------------------------------------------------------------------
    // Adler-32 / zlib tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_adler32_empty() {
        assert_eq!(adler32(&[]), 0x00000001);
    }

    #[test]
    fn test_adler32_known() {
        // adler32("Wikipedia") = 0x11E60398
        let data = b"Wikipedia";
        assert_eq!(adler32(data), 0x11E60398);
    }

    // -----------------------------------------------------------------------
    // DEFLATE / Huffman tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_huffman_tree_from_lengths() {
        // Simple tree: symbols 0,1 with lengths 1,1 => codes 0,1
        let lengths = [1u8, 1];
        let tree = HuffmanTree::from_lengths(&lengths).unwrap();
        assert_eq!(tree.counts[1], 2);
        assert_eq!(tree.symbols.len(), 2);
    }

    #[test]
    fn test_deflate_stored_block() {
        // Minimal stored block carrying "abc".
        let block = [
            0x01, // BFINAL=1, BTYPE=00 (stored): bits 1 00; the rest of the byte is padding
            0x03, 0x00, // LEN = 3 (little-endian)
            0xFC, 0xFF, // NLEN = !3 = 0xFFFC (little-endian)
            b'a', b'b', b'c', // raw data
        ];
        let result = deflate_decompress(&block).unwrap();
        assert_eq!(result, b"abc");
    }

    #[test]
    fn test_deflate_fixed_huffman_literal() {
        // Fixed-Huffman block holding "a", packed LSB-first:
        //   BFINAL=1, BTYPE=01,
        //   literal 'a' (8-bit code 10010001, sent MSB-first),
        //   end-of-block (7-bit code 0000000).
        // This is the same raw DEFLATE stream zlib emits for "a".
        let data = [0x4B, 0x04, 0x00];
        let result = deflate_decompress(&data).unwrap();
        assert_eq!(result, b"a");
    }

    #[test]
    fn test_deflate_fixed_huffman_overlapping_copy() {
        // Fixed-Huffman block:
        //   BFINAL=1, BTYPE=01,
        //   literal 'a',
        //   length code 257 (length 3, 7-bit code 0000001),
        //   distance code 0 (distance 1, 5-bit code 00000),
        //   end-of-block.
        // The distance-1 match overlaps its own output, producing "aaaa".
        let data = [0x4B, 0x04, 0x02, 0x00];
        let result = deflate_decompress(&data).unwrap();
        assert_eq!(result, b"aaaa");
    }

    #[test]
    fn test_deflate_distance_before_start_rejected() {
        // Fixed-Huffman block whose first symbol is a back-reference
        // (length 3, distance 1) before any byte has been produced.
        let data = [0x03, 0x02, 0x00];
        assert!(deflate_decompress(&data).is_err());
    }

    #[test]
    fn test_bit_reader_basic() {
        let data = [0b10110100, 0xFF];
        let mut reader = BitReader::new(&data);
        // Read 4 bits LSB first from 0b10110100 => bits: 0,0,1,0 => 0b0100 = 4
        let v = reader.read_bits(4).unwrap();
        assert_eq!(v, 0b0100);
    }
}