⚠️ VeridianOS Kernel Documentation — This is low-level kernel code built for a `no_std` environment. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/media/image_codecs/
jpeg.rs

1//! JPEG decoder (Baseline DCT, SOF0).
2//!
3//! Supports Huffman entropy coding, integer IDCT, YCbCr-to-RGB via
4//! fixed-point arithmetic, chroma subsampling 4:4:4/4:2:2/4:2:0,
5//! and restart intervals.
6
7#![allow(dead_code)]
8
9use alloc::vec::Vec;
10
11use super::{clamp_u8, read_be_u16, DecodedImage, ImageCodecError};
12
13// ============================================================================
14// JPEG DECODER (Baseline DCT)
15// ============================================================================
16
/// Start Of Frame, baseline DCT — the only frame type this decoder parses.
const JPEG_SOF0: u16 = 0xFFC0;
/// Define Huffman Table(s) segment.
const JPEG_DHT: u16 = 0xFFC4;
/// First restart marker; RST1..RST7 follow consecutively (0xFFD1..=0xFFD7).
const JPEG_RST0: u16 = 0xFFD0;
/// Start Of Image: the two bytes every JPEG stream begins with.
const JPEG_SOI: u16 = 0xFFD8;
/// End Of Image.
const JPEG_EOI: u16 = 0xFFD9;
/// Start Of Scan; the entropy-coded data follows its header segment.
const JPEG_SOS: u16 = 0xFFDA;
/// Define Quantization Table(s) segment.
const JPEG_DQT: u16 = 0xFFDB;
/// Define Restart Interval segment.
const JPEG_DRI: u16 = 0xFFDD;
27
/// Per-component parameters taken from the frame (SOF0) and scan (SOS)
/// headers, plus a running DC predictor for entropy decoding.
#[derive(Debug, Clone, Copy, Default)]
struct JpegComponent {
    /// Component identifier, matched against the selectors in the SOS header.
    id: u8,
    /// Horizontal sampling factor.
    h_samp: u8,
    /// Vertical sampling factor.
    v_samp: u8,
    /// Index of the quantization table this component uses.
    quant_table: u8,
    /// Index of the DC Huffman table this component uses.
    dc_table: u8,
    /// Index of the AC Huffman table this component uses.
    ac_table: u8,
    /// DC prediction value (accumulated DC differences).
    dc_pred: i32,
}
39
/// Canonical Huffman table used for JPEG entropy decoding.
///
/// Besides the raw DHT contents (`counts`, `symbols`), it keeps the per-length
/// lookup arrays of the ITU-T T.81 F.2.2.3 decoding procedure: `max_code`
/// corresponds to MAXCODE and `val_offset` to VALPTR - MINCODE.
#[derive(Clone)]
struct JpegHuffTable {
    /// `counts[len]` is the number of codes of bit length `len` (1..=16); index 0 is unused.
    counts: [u8; 17],
    /// Symbol values in order of increasing code length.
    symbols: Vec<u8>,
    /// Largest code of each bit length, or -1 when that length has no codes.
    max_code: [i32; 18],
    /// Added to a code of a given length to obtain its index into `symbols`.
    val_offset: [i32; 18],
}

impl Default for JpegHuffTable {
    /// An empty table: no code of any length will ever match.
    fn default() -> Self {
        Self {
            counts: [0; 17],
            symbols: Vec::new(),
            max_code: [-1; 18],
            val_offset: [0; 18],
        }
    }
}

impl JpegHuffTable {
    /// Derives `max_code` and `val_offset` from `counts`.
    ///
    /// Codes are assigned canonically: the codes of one length are consecutive,
    /// and the first code of the next length is `(last code + 1) << 1`.
    /// Lengths without codes get `max_code = -1`; their `val_offset` is left as is.
    fn build(&mut self) {
        let mut next_code: i32 = 0;
        let mut first_index: i32 = 0;

        for len in 1..=16 {
            let count = i32::from(self.counts[len]);
            if count == 0 {
                self.max_code[len] = -1;
            } else {
                self.val_offset[len] = first_index - next_code;
                self.max_code[len] = next_code + count - 1;
            }
            first_index += count;
            next_code = (next_code + count) << 1;
        }
    }
}
83
/// Bit-level reader over JPEG entropy-coded data.
///
/// Bytes are fed MSB-first into a 32-bit buffer. The reader's byte-fetch
/// routine undoes `0xFF 0x00` byte stuffing (yielding a literal 0xFF).
struct JpegBitReader<'a> {
    /// The complete input buffer.
    data: &'a [u8],
    /// Offset in `data` of the next byte to fetch.
    pos: usize,
    /// Buffered bits, left-aligned: the next bit to be read is bit 31.
    bit_buf: u32,
    /// Number of valid bits at the top of `bit_buf`.
    bits_left: u8,
}
91
92impl<'a> JpegBitReader<'a> {
93    fn new(data: &'a [u8], start: usize) -> Self {
94        Self {
95            data,
96            pos: start,
97            bit_buf: 0,
98            bits_left: 0,
99        }
100    }
101
102    /// Read the next byte, handling JPEG byte stuffing.
103    fn next_byte(&mut self) -> Result<u8, ImageCodecError> {
104        if self.pos >= self.data.len() {
105            return Err(ImageCodecError::TruncatedData);
106        }
107        let b = self.data[self.pos];
108        self.pos += 1;
109
110        if b == 0xFF {
111            if self.pos >= self.data.len() {
112                return Err(ImageCodecError::TruncatedData);
113            }
114            let next = self.data[self.pos];
115            if next == 0x00 {
116                // Byte-stuffed 0xFF
117                self.pos += 1;
118                Ok(0xFF)
119            } else if (0xD0..=0xD7).contains(&next) {
120                // RST marker -- skip it and return next real byte
121                self.pos += 1;
122                self.next_byte()
123            } else {
124                // Other marker -- signal end of scan
125                self.pos -= 1;
126                Err(ImageCodecError::TruncatedData)
127            }
128        } else {
129            Ok(b)
130        }
131    }
132
133    /// Fill the bit buffer.
134    fn fill_bits(&mut self) -> Result<(), ImageCodecError> {
135        while self.bits_left <= 24 {
136            let b = self.next_byte()?;
137            self.bit_buf |= (b as u32) << (24 - self.bits_left);
138            self.bits_left += 8;
139        }
140        Ok(())
141    }
142
143    /// Read `n` bits from MSB.
144    fn read_bits(&mut self, n: u8) -> Result<i32, ImageCodecError> {
145        if n == 0 {
146            return Ok(0);
147        }
148        while self.bits_left < n {
149            let b = self.next_byte()?;
150            self.bit_buf |= (b as u32) << (24 - self.bits_left);
151            self.bits_left += 8;
152        }
153        let val = (self.bit_buf >> (32 - n)) as i32;
154        self.bit_buf <<= n;
155        self.bits_left -= n;
156        Ok(val)
157    }
158
159    /// Decode one Huffman symbol.
160    fn decode_huff(&mut self, table: &JpegHuffTable) -> Result<u8, ImageCodecError> {
161        // Ensure enough bits in buffer
162        while self.bits_left < 16 {
163            match self.next_byte() {
164                Ok(b) => {
165                    self.bit_buf |= (b as u32) << (24 - self.bits_left);
166                    self.bits_left += 8;
167                }
168                Err(_) => break,
169            }
170        }
171
172        let mut code: i32 = 0;
173        for bits in 1..=16u8 {
174            code = (code << 1) | ((self.bit_buf >> 31) as i32);
175            self.bit_buf <<= 1;
176            self.bits_left = self.bits_left.saturating_sub(1);
177
178            if code <= table.max_code[bits as usize] {
179                let idx = (code + table.val_offset[bits as usize]) as usize;
180                if idx < table.symbols.len() {
181                    return Ok(table.symbols[idx]);
182                }
183            }
184        }
185
186        Err(ImageCodecError::InvalidHuffmanTable)
187    }
188
189    /// Receive and extend a value category.
190    fn receive_extend(&mut self, nbits: u8) -> Result<i32, ImageCodecError> {
191        if nbits == 0 {
192            return Ok(0);
193        }
194        let val = self.read_bits(nbits)?;
195        // Sign extension: if MSB is 0, value is negative
196        if val < (1 << (nbits - 1)) {
197            Ok(val - (1 << nbits) + 1)
198        } else {
199            Ok(val)
200        }
201    }
202
203    /// Reset bit reader state (for restart markers).
204    fn reset_bits(&mut self) {
205        self.bit_buf = 0;
206        self.bits_left = 0;
207    }
208}
209
210/// Decode a JPEG image (baseline DCT only).
211pub fn decode_jpeg(data: &[u8]) -> Result<DecodedImage, ImageCodecError> {
212    if data.len() < 4 {
213        return Err(ImageCodecError::TruncatedData);
214    }
215
216    // Check SOI marker
217    if data[0] != 0xFF || data[1] != 0xD8 {
218        return Err(ImageCodecError::InvalidSignature);
219    }
220
221    // Parse markers
222    let mut pos: usize = 2;
223    let mut width: u32 = 0;
224    let mut height: u32 = 0;
225    let mut num_components: usize = 0;
226    let mut components = [JpegComponent::default(); 4];
227    let mut max_h_samp: u8 = 1;
228    let mut max_v_samp: u8 = 1;
229    let mut quant_tables = [[0i32; 64]; 4];
230    let mut dc_tables = [JpegHuffTable::default(), JpegHuffTable::default()];
231    let mut ac_tables = [JpegHuffTable::default(), JpegHuffTable::default()];
232    let mut restart_interval: u16 = 0;
233    let mut scan_start: usize = 0;
234
235    while pos + 1 < data.len() {
236        if data[pos] != 0xFF {
237            pos += 1;
238            continue;
239        }
240
241        let marker = ((data[pos] as u16) << 8) | data[pos + 1] as u16;
242        pos += 2;
243
244        match marker {
245            JPEG_EOI => break,
246            JPEG_SOF0 => {
247                // Baseline DCT frame header
248                if pos + 2 > data.len() {
249                    return Err(ImageCodecError::TruncatedData);
250                }
251                let seg_len = read_be_u16(data, pos) as usize;
252                if pos + seg_len > data.len() || seg_len < 8 {
253                    return Err(ImageCodecError::TruncatedData);
254                }
255                let _precision = data[pos + 2]; // must be 8 for baseline
256                height = read_be_u16(data, pos + 3) as u32;
257                width = read_be_u16(data, pos + 5) as u32;
258                num_components = data[pos + 7] as usize;
259
260                if num_components > 4 || num_components == 0 {
261                    return Err(ImageCodecError::Unsupported);
262                }
263
264                for (i, comp) in components.iter_mut().enumerate().take(num_components) {
265                    let off = pos + 8 + i * 3;
266                    if off + 2 >= data.len() {
267                        return Err(ImageCodecError::TruncatedData);
268                    }
269                    comp.id = data[off];
270                    comp.h_samp = (data[off + 1] >> 4) & 0x0F;
271                    comp.v_samp = data[off + 1] & 0x0F;
272                    comp.quant_table = data[off + 2];
273
274                    if comp.h_samp > max_h_samp {
275                        max_h_samp = comp.h_samp;
276                    }
277                    if comp.v_samp > max_v_samp {
278                        max_v_samp = comp.v_samp;
279                    }
280                }
281
282                pos += seg_len;
283            }
284            JPEG_DHT => {
285                // Huffman table definition
286                if pos + 2 > data.len() {
287                    return Err(ImageCodecError::TruncatedData);
288                }
289                let seg_len = read_be_u16(data, pos) as usize;
290                let seg_end = pos + seg_len;
291                if seg_end > data.len() {
292                    return Err(ImageCodecError::TruncatedData);
293                }
294                let mut p = pos + 2;
295
296                while p < seg_end {
297                    if p >= data.len() {
298                        break;
299                    }
300                    let info = data[p];
301                    p += 1;
302                    let table_class = (info >> 4) & 0x0F; // 0=DC, 1=AC
303                    let table_id = (info & 0x0F) as usize;
304
305                    if table_id > 1 {
306                        return Err(ImageCodecError::InvalidHuffmanTable);
307                    }
308
309                    let mut table = JpegHuffTable::default();
310                    let mut total_symbols: usize = 0;
311                    for i in 1..=16 {
312                        if p >= data.len() {
313                            return Err(ImageCodecError::TruncatedData);
314                        }
315                        table.counts[i] = data[p];
316                        total_symbols += data[p] as usize;
317                        p += 1;
318                    }
319
320                    table.symbols = Vec::with_capacity(total_symbols);
321                    for _ in 0..total_symbols {
322                        if p >= data.len() {
323                            return Err(ImageCodecError::TruncatedData);
324                        }
325                        table.symbols.push(data[p]);
326                        p += 1;
327                    }
328
329                    table.build();
330
331                    if table_class == 0 {
332                        dc_tables[table_id] = table;
333                    } else {
334                        ac_tables[table_id] = table;
335                    }
336                }
337
338                pos = seg_end;
339            }
340            JPEG_DQT => {
341                // Quantization table
342                if pos + 2 > data.len() {
343                    return Err(ImageCodecError::TruncatedData);
344                }
345                let seg_len = read_be_u16(data, pos) as usize;
346                let seg_end = pos + seg_len;
347                if seg_end > data.len() {
348                    return Err(ImageCodecError::TruncatedData);
349                }
350                let mut p = pos + 2;
351
352                while p < seg_end {
353                    if p >= data.len() {
354                        break;
355                    }
356                    let info = data[p];
357                    p += 1;
358                    let precision = (info >> 4) & 0x0F; // 0=8-bit, 1=16-bit
359                    let table_id = (info & 0x0F) as usize;
360
361                    if table_id > 3 {
362                        return Err(ImageCodecError::InvalidQuantTable);
363                    }
364
365                    for qt in quant_tables[table_id].iter_mut() {
366                        if precision == 0 {
367                            if p >= data.len() {
368                                return Err(ImageCodecError::TruncatedData);
369                            }
370                            *qt = data[p] as i32;
371                            p += 1;
372                        } else {
373                            if p + 1 >= data.len() {
374                                return Err(ImageCodecError::TruncatedData);
375                            }
376                            *qt = read_be_u16(data, p) as i32;
377                            p += 2;
378                        }
379                    }
380                }
381
382                pos = seg_end;
383            }
384            JPEG_DRI => {
385                // Restart interval
386                if pos + 2 > data.len() {
387                    return Err(ImageCodecError::TruncatedData);
388                }
389                let _seg_len = read_be_u16(data, pos);
390                if pos + 4 > data.len() {
391                    return Err(ImageCodecError::TruncatedData);
392                }
393                restart_interval = read_be_u16(data, pos + 2);
394                pos += _seg_len as usize;
395            }
396            JPEG_SOS => {
397                // Start of scan
398                if pos + 2 > data.len() {
399                    return Err(ImageCodecError::TruncatedData);
400                }
401                let seg_len = read_be_u16(data, pos) as usize;
402                if pos + seg_len > data.len() {
403                    return Err(ImageCodecError::TruncatedData);
404                }
405
406                let ns = data[pos + 2] as usize;
407                for i in 0..ns {
408                    let off = pos + 3 + i * 2;
409                    if off + 1 >= data.len() {
410                        return Err(ImageCodecError::TruncatedData);
411                    }
412                    let comp_id = data[off];
413                    let td_ta = data[off + 1];
414
415                    // Find matching component
416                    for comp in components.iter_mut().take(num_components) {
417                        if comp.id == comp_id {
418                            comp.dc_table = (td_ta >> 4) & 0x0F;
419                            comp.ac_table = td_ta & 0x0F;
420                        }
421                    }
422                }
423
424                scan_start = pos + seg_len;
425                break;
426            }
427            _ => {
428                // Skip unknown marker segment
429                if pos + 2 > data.len() {
430                    break;
431                }
432                if marker >= 0xFFC0 && marker != 0xFF00 {
433                    let seg_len = read_be_u16(data, pos) as usize;
434                    pos += seg_len;
435                }
436            }
437        }
438    }
439
440    if width == 0 || height == 0 {
441        return Err(ImageCodecError::InvalidDimensions);
442    }
443
444    // MCU dimensions
445    let mcu_w = (max_h_samp as u32) * 8;
446    let mcu_h = (max_v_samp as u32) * 8;
447    let mcus_x = width.div_ceil(mcu_w);
448    let mcus_y = height.div_ceil(mcu_h);
449
450    // Allocate component planes
451    let mut comp_data: Vec<Vec<i32>> = Vec::new();
452    for comp in components.iter().take(num_components) {
453        let cw = mcus_x as usize * (comp.h_samp as usize) * 8;
454        let ch = mcus_y as usize * (comp.v_samp as usize) * 8;
455        comp_data.push(alloc::vec![0i32; cw * ch]);
456    }
457
458    // Entropy decode MCUs
459    let mut reader = JpegBitReader::new(data, scan_start);
460    let mut mcu_count: u32 = 0;
461
462    // Reset DC predictors
463    for c in components.iter_mut().take(num_components) {
464        c.dc_pred = 0;
465    }
466
467    for mcu_y in 0..mcus_y {
468        for mcu_x in 0..mcus_x {
469            // Check restart interval
470            if restart_interval > 0
471                && mcu_count > 0
472                && mcu_count.is_multiple_of(restart_interval as u32)
473            {
474                // Reset DC predictors and bit reader
475                for c in components.iter_mut().take(num_components) {
476                    c.dc_pred = 0;
477                }
478                reader.reset_bits();
479                // Skip to next RST marker
480                while reader.pos < data.len() {
481                    if data[reader.pos] == 0xFF && reader.pos + 1 < data.len() {
482                        let m = data[reader.pos + 1];
483                        if (0xD0..=0xD7).contains(&m) {
484                            reader.pos += 2;
485                            break;
486                        }
487                    }
488                    reader.pos += 1;
489                }
490            }
491
492            // Decode each component's blocks in this MCU
493            for ci in 0..num_components {
494                let h_samp = components[ci].h_samp as usize;
495                let v_samp = components[ci].v_samp as usize;
496                let qt_idx = components[ci].quant_table as usize;
497                let dc_idx = components[ci].dc_table as usize;
498                let ac_idx = components[ci].ac_table as usize;
499                let comp_w = mcus_x as usize * h_samp * 8;
500
501                for sv in 0..v_samp {
502                    for sh in 0..h_samp {
503                        // Decode one 8x8 block
504                        let mut block = [0i32; 64];
505
506                        // DC coefficient
507                        let dc_sym = reader.decode_huff(&dc_tables[dc_idx])?;
508                        let dc_diff = reader.receive_extend(dc_sym)?;
509                        components[ci].dc_pred += dc_diff;
510                        block[0] = components[ci].dc_pred;
511
512                        // AC coefficients (zig-zag order)
513                        let mut k: usize = 1;
514                        while k < 64 {
515                            let ac_sym = reader.decode_huff(&ac_tables[ac_idx])?;
516                            if ac_sym == 0x00 {
517                                // End of block
518                                break;
519                            }
520                            let run = (ac_sym >> 4) & 0x0F;
521                            let size = ac_sym & 0x0F;
522
523                            if size == 0 && run == 0x0F {
524                                // ZRL: skip 16 zeros
525                                k += 16;
526                                continue;
527                            }
528
529                            k += run as usize;
530                            if k >= 64 {
531                                break;
532                            }
533
534                            block[ZIGZAG[k] as usize] = reader.receive_extend(size)?;
535                            k += 1;
536                        }
537
538                        // Dequantize
539                        if qt_idx < 4 {
540                            for i in 0..64 {
541                                block[i] = block[i]
542                                    .checked_mul(quant_tables[qt_idx][i])
543                                    .unwrap_or(block[i]);
544                            }
545                        }
546
547                        // IDCT
548                        idct_integer(&mut block);
549
550                        // Store block into component plane
551                        let bx = mcu_x as usize * h_samp * 8 + sh * 8;
552                        let by = mcu_y as usize * v_samp * 8 + sv * 8;
553
554                        for row in 0..8 {
555                            for col in 0..8 {
556                                let px = bx + col;
557                                let py = by + row;
558                                let idx = py * comp_w + px;
559                                if idx < comp_data[ci].len() {
560                                    // Level shift: add 128 to bring from [-128,127] to [0,255]
561                                    comp_data[ci][idx] = block[row * 8 + col] + 128;
562                                }
563                            }
564                        }
565                    }
566                }
567            }
568
569            mcu_count += 1;
570        }
571    }
572
573    // Convert component planes to RGB output
574    let mut img = DecodedImage::new(width, height);
575
576    if num_components == 1 {
577        // Grayscale
578        let comp_w = mcus_x as usize * 8;
579        for y in 0..height {
580            for x in 0..width {
581                let idx = y as usize * comp_w + x as usize;
582                if idx < comp_data[0].len() {
583                    let v = clamp_u8(comp_data[0][idx]);
584                    img.set_pixel(x, y, v, v, v, 255);
585                }
586            }
587        }
588    } else if num_components >= 3 {
589        // YCbCr to RGB conversion
590        let y_w = mcus_x as usize * (components[0].h_samp as usize) * 8;
591        let cb_w = mcus_x as usize * (components[1].h_samp as usize) * 8;
592        let cr_w = mcus_x as usize * (components[2].h_samp as usize) * 8;
593
594        let h0 = components[0].h_samp as usize;
595        let v0 = components[0].v_samp as usize;
596        let h1 = components[1].h_samp as usize;
597        let v1 = components[1].v_samp as usize;
598
599        for py in 0..height {
600            for px in 0..width {
601                // Y sample position
602                let y_idx = py as usize * y_w + px as usize;
603
604                // Cb/Cr sample position (with subsampling)
605                let cx = if h0 > 0 { (px as usize * h1) / h0 } else { 0 };
606                let cy = if v0 > 0 { (py as usize * v1) / v0 } else { 0 };
607                let cb_idx = cy * cb_w + cx;
608                let cr_idx = cy * cr_w + cx;
609
610                let yv = if y_idx < comp_data[0].len() {
611                    comp_data[0][y_idx]
612                } else {
613                    128
614                };
615                let cb = if cb_idx < comp_data[1].len() {
616                    comp_data[1][cb_idx] - 128
617                } else {
618                    0
619                };
620                let cr = if cr_idx < comp_data[2].len() {
621                    comp_data[2][cr_idx] - 128
622                } else {
623                    0
624                };
625
626                // Fixed-point YCbCr to RGB (BT.601)
627                // R = Y + 1.402 * Cr        => Y + (359 * Cr) >> 8
628                // G = Y - 0.344 * Cb - 0.714 * Cr => Y - (88 * Cb + 183 * Cr) >> 8
629                // B = Y + 1.772 * Cb         => Y + (454 * Cb) >> 8
630                let r = clamp_u8(yv + ((359 * cr) >> 8));
631                let g = clamp_u8(yv - ((88 * cb + 183 * cr) >> 8));
632                let b = clamp_u8(yv + ((454 * cb) >> 8));
633
634                img.set_pixel(px, py, r, g, b, 255);
635            }
636        }
637    }
638
639    Ok(img)
640}
641
/// Integer 8x8 inverse DCT, computed in place.
///
/// Input: 64 dequantized DCT coefficients in natural (row-major) order.
/// Output: 64 spatial-domain values centred on 0, clamped to `-256..=255`.
/// The caller still adds the +128 level shift.
///
/// This is a separable row/column integer IDCT. Its constants `W1..W7` are
/// `2048 * sqrt(2) * cos(k * pi / 16)` (11 fractional bits). Its stage layout
/// appears to match the Chen-Wang integer IDCT of the MPEG-2 reference decoder.
/// That algorithm is only safe in `i32` for coefficients of real DCT data;
/// hostile coefficients overflow its intermediates, which is an arithmetic
/// panic in debug builds. All arithmetic is therefore done in `i64`. This is
/// bit-identical to `i32` whenever `i32` would not overflow.
///
/// The final clamp keeps every value that matters after the caller's +128
/// shift and 0..=255 clamp. For real images it is a no-op.
fn idct_integer(block: &mut [i32; 64]) {
    // 2048 * sqrt(2) * cos(k * pi / 16), rounded.
    const W1: i64 = 2841; // k = 1
    const W2: i64 = 2676; // k = 2
    const W3: i64 = 2408; // k = 3
    const W5: i64 = 1609; // k = 5
    const W6: i64 = 1108; // k = 6
    const W7: i64 = 565; // k = 7

    /// Narrows a finished output sample, clamping it to `-256..=255`.
    fn clip(v: i64) -> i32 {
        v.clamp(-256, 255) as i32
    }

    // Widened working copy.
    let mut ws = [0i64; 64];
    for (w, &c) in ws.iter_mut().zip(block.iter()) {
        *w = i64::from(c);
    }

    // Pass 1: 1-D IDCT on each row. Results carry 3 extra fractional bits,
    // so a DC-only row becomes `dc << 3`.
    for row in ws.chunks_exact_mut(8) {
        if row[1..].iter().all(|&c| c == 0) {
            let dc = row[0] << 3;
            row.fill(dc);
            continue;
        }

        // Prescale; +128 rounds the final >> 8.
        let mut x0 = (row[0] << 11) + 128;
        let mut x1 = row[4] << 11;
        let x2 = row[6];
        let x3 = row[2];
        let x4 = row[1];
        let x5 = row[7];
        let x6 = row[5];
        let x7 = row[3];

        // Stage 1 -- odd part (coefficients 1, 7, 5, 3)
        let x8 = W7 * (x4 + x5);
        let mut x4r = x8 + (W1 - W7) * x4;
        let mut x5r = x8 - (W1 + W7) * x5;
        let x8 = W3 * (x6 + x7);
        let mut x6r = x8 - (W3 - W5) * x6;
        let mut x7r = x8 - (W3 + W5) * x7;

        // Stage 2 -- even part (coefficients 0, 4, 6, 2) and odd butterflies
        x0 += x1;
        x1 = x0 - (x1 << 1);
        let x8_2 = W6 * (x2 + x3);
        let x2r = x8_2 - (W2 + W6) * x2;
        let x3r = x8_2 + (W2 - W6) * x3;
        x4r += x6r;
        x6r = x4r - (x6r << 1);
        x5r += x7r;
        x7r = x5r - (x7r << 1);

        // Stage 3 (181/256 ~= 1/sqrt(2))
        x0 += x3r;
        let x3s = x0 - (x3r << 1);
        x1 += x2r;
        let x2s = x1 - (x2r << 1);
        let tmp = ((x6r + x7r) * 181 + 128) >> 8;
        x6r = ((x6r - x7r) * 181 + 128) >> 8;

        // Output
        row[0] = (x0 + x4r) >> 8;
        row[1] = (x1 + tmp) >> 8;
        row[2] = (x2s + x6r) >> 8;
        row[3] = (x3s + x5r) >> 8;
        row[4] = (x3s - x5r) >> 8;
        row[5] = (x2s - x6r) >> 8;
        row[6] = (x1 - tmp) >> 8;
        row[7] = (x0 - x4r) >> 8;
    }

    // Pass 2: 1-D IDCT on each column, writing final clamped samples to `block`.
    for i in 0..8 {
        let at = |r: usize| ws[r * 8 + i];

        // All-zero column except DC: flat result.
        if (1..8).all(|r| at(r) == 0) {
            let dc = clip((at(0) + 32) >> 6);
            for r in 0..8 {
                block[r * 8 + i] = dc;
            }
            continue;
        }

        // Prescale; +8192 rounds the final >> 14.
        let mut x0 = (at(0) << 8) + 8192;
        let mut x1 = at(4) << 8;
        let x2 = at(6);
        let x3 = at(2);
        let x4 = at(1);
        let x5 = at(7);
        let x6 = at(5);
        let x7 = at(3);

        // Stage 1 -- odd part (with rounding before >> 3)
        let x8 = W7 * (x4 + x5) + 4;
        let mut x4r = (x8 + (W1 - W7) * x4) >> 3;
        let mut x5r = (x8 - (W1 + W7) * x5) >> 3;
        let x8 = W3 * (x6 + x7) + 4;
        let mut x6r = (x8 - (W3 - W5) * x6) >> 3;
        let mut x7r = (x8 - (W3 + W5) * x7) >> 3;

        // Stage 2 -- even part and odd butterflies
        x0 += x1;
        x1 = x0 - (x1 << 1);
        let x8_2 = W6 * (x2 + x3) + 4;
        let x2r = (x8_2 - (W2 + W6) * x2) >> 3;
        let x3r = (x8_2 + (W2 - W6) * x3) >> 3;
        x4r += x6r;
        x6r = x4r - (x6r << 1);
        x5r += x7r;
        x7r = x5r - (x7r << 1);

        // Stage 3
        x0 += x3r;
        let x3s = x0 - (x3r << 1);
        x1 += x2r;
        let x2s = x1 - (x2r << 1);
        let tmp = ((x6r + x7r) * 181 + 128) >> 8;
        x6r = ((x6r - x7r) * 181 + 128) >> 8;

        block[i] = clip((x0 + x4r) >> 14);
        block[8 + i] = clip((x1 + tmp) >> 14);
        block[16 + i] = clip((x2s + x6r) >> 14);
        block[24 + i] = clip((x3s + x5r) >> 14);
        block[32 + i] = clip((x3s - x5r) >> 14);
        block[40 + i] = clip((x2s - x6r) >> 14);
        block[48 + i] = clip((x1 - tmp) >> 14);
        block[56 + i] = clip((x0 - x4r) >> 14);
    }
}
786
787/// JPEG zig-zag scan order: maps linear index to natural order position.
788const ZIGZAG: [u8; 64] = [
789    0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, 48, 41, 34, 27, 20,
790    13, 6, 7, 14, 21, 28, 35, 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51, 58, 59,
791    52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63,
792];
793
794// ============================================================================
795// TESTS
796// ============================================================================
797
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    #[test]
    fn test_jpeg_signature_check() {
        // Long enough to pass the length check, but SOI is not at offset 0.
        let input = [0x00, 0x00, 0xFF, 0xD8];
        let outcome = decode_jpeg(&input);
        assert_eq!(outcome, Err(ImageCodecError::InvalidSignature));
    }

    #[test]
    fn test_jpeg_too_short() {
        // A lone SOI marker is below the 4-byte minimum and must be rejected.
        let input = [0xFF, 0xD8];
        assert!(decode_jpeg(&input).is_err());
    }

    #[test]
    fn test_jpeg_zigzag_order() {
        // (transmission index, natural index) pairs at the start and end of the scan.
        let expected: [(usize, u8); 6] = [(0, 0), (1, 1), (2, 8), (3, 16), (4, 9), (63, 63)];
        for (k, natural) in expected {
            assert_eq!(ZIGZAG[k], natural);
        }
    }

    #[test]
    fn test_jpeg_huff_table_build() {
        // A single one-bit code (`0`) that maps to symbol 0x42.
        let mut table = JpegHuffTable {
            symbols: vec![0x42],
            ..JpegHuffTable::default()
        };
        table.counts[1] = 1;
        table.build();
        assert_eq!(table.max_code[1], 0);
        assert_eq!(table.val_offset[1], 0);
    }

    #[test]
    fn test_idct_dc_only() {
        // A DC-only block must come out flat, up to rounding.
        let mut block = [0i32; 64];
        block[0] = 100;
        idct_integer(&mut block);
        let dc_val = block[0];
        for v in block {
            assert!((v - dc_val).abs() <= 1, "Expected ~{}, got {}", dc_val, v);
        }
    }
}