⚠️ VeridianOS Kernel Documentation — This is low-level, `no_std` kernel code. All functions are unsafe unless explicitly marked otherwise.

veridian_kernel/crypto/post_quantum/kyber.rs

1//! ML-KEM (Kyber) Key Encapsulation Mechanism
2//!
3//! Implements lattice-based key encapsulation following NIST FIPS 203.
4//! Provides quantum-resistant key exchange at security levels 512, 768, 1024.
5
6use alloc::vec::Vec;
7
8use super::{
9    bit_reverse_7, expand_seed, kyber_barrett_reduce, mod_pow, KyberLevel, KYBER_N, KYBER_Q,
10    KYBER_ZETA,
11};
12use crate::crypto::CryptoResult;
13
14// ============================================================================
15// NTT Constants and Helpers
16// ============================================================================
17
/// Montgomery radix for Kyber arithmetic, `R = 2^16 mod q` with `q = 3329`.
const KYBER_MONT_R: u32 = (1 << 16) % 3329;

/// Barrett multiplier for Kyber, `round(2^26 / q)` with `q = 3329`.
const KYBER_BARRETT_V: u32 = ((1 << 26) + 3329 / 2) / 3329;
23
24/// Precomputed NTT zetas (powers of the root of unity) for Kyber
25/// zetas[i] = KYBER_ZETA^(bit_reverse_7(i)) mod q
26fn kyber_ntt_zetas() -> [u16; 128] {
27    let mut zetas = [0u16; 128];
28    let mut i = 0;
29    while i < 128 {
30        let rev = bit_reverse_7(i as u8) as u32;
31        zetas[i] = mod_pow(KYBER_ZETA, rev, KYBER_Q) as u16;
32        i += 1;
33    }
34    zetas
35}
36
37// ============================================================================
38// Kyber Polynomial
39// ============================================================================
40
41/// Kyber polynomial (degree N-1 with coefficients mod q)
42#[derive(Clone)]
43pub(super) struct KyberPoly {
44    coeffs: [i16; KYBER_N],
45}
46
47impl KyberPoly {
48    pub(super) fn zero() -> Self {
49        KyberPoly {
50            coeffs: [0i16; KYBER_N],
51        }
52    }
53
54    /// Forward NTT (number theoretic transform) in place
55    /// Transforms polynomial from normal form to NTT domain
56    pub(super) fn ntt(&mut self) {
57        let zetas = kyber_ntt_zetas();
58        let mut k = 1usize;
59        let mut len = 128;
60        while len >= 2 {
61            let mut start = 0;
62            while start < KYBER_N {
63                let zeta = zetas[k] as i32;
64                k += 1;
65                let mut j = start;
66                while j < start + len {
67                    let t = ((zeta as i64 * self.coeffs[j + len] as i64) % KYBER_Q as i64) as i32;
68                    let t = kyber_barrett_reduce(t) as i16;
69                    self.coeffs[j + len] =
70                        kyber_barrett_reduce(self.coeffs[j] as i32 - t as i32) as i16;
71                    self.coeffs[j] = kyber_barrett_reduce(self.coeffs[j] as i32 + t as i32) as i16;
72                    j += 1;
73                }
74                start += 2 * len;
75            }
76            len >>= 1;
77        }
78    }
79
80    /// Inverse NTT
81    pub(super) fn inv_ntt(&mut self) {
82        let zetas = kyber_ntt_zetas();
83        let mut k = 127usize;
84        let mut len = 2;
85        while len <= 128 {
86            let mut start = 0;
87            while start < KYBER_N {
88                let zeta = zetas[k] as i32;
89                k = k.wrapping_sub(1);
90                let mut j = start;
91                while j < start + len {
92                    let t = self.coeffs[j] as i32;
93                    self.coeffs[j] = kyber_barrett_reduce(t + self.coeffs[j + len] as i32) as i16;
94                    let diff = t - self.coeffs[j + len] as i32;
95                    self.coeffs[j + len] =
96                        kyber_barrett_reduce(((zeta as i64 * diff as i64) % KYBER_Q as i64) as i32)
97                            as i16;
98                    j += 1;
99                }
100                start += 2 * len;
101            }
102            len <<= 1;
103        }
104
105        // Multiply by N^(-1) mod q
106        let n_inv = mod_pow(256, KYBER_Q - 2, KYBER_Q) as i32;
107        let mut i = 0;
108        while i < KYBER_N {
109            self.coeffs[i] = kyber_barrett_reduce(
110                ((self.coeffs[i] as i64 * n_inv as i64) % KYBER_Q as i64) as i32,
111            ) as i16;
112            i += 1;
113        }
114    }
115
116    /// Pointwise multiplication in NTT domain
117    pub(super) fn pointwise_mul(&self, other: &KyberPoly) -> KyberPoly {
118        let mut result = KyberPoly::zero();
119        let mut i = 0;
120        while i < KYBER_N {
121            result.coeffs[i] = kyber_barrett_reduce(
122                ((self.coeffs[i] as i64 * other.coeffs[i] as i64) % KYBER_Q as i64) as i32,
123            ) as i16;
124            i += 1;
125        }
126        result
127    }
128
129    /// Add two polynomials
130    pub(super) fn add(&self, other: &KyberPoly) -> KyberPoly {
131        let mut result = KyberPoly::zero();
132        let mut i = 0;
133        while i < KYBER_N {
134            result.coeffs[i] =
135                kyber_barrett_reduce(self.coeffs[i] as i32 + other.coeffs[i] as i32) as i16;
136            i += 1;
137        }
138        result
139    }
140
141    /// Subtract two polynomials
142    pub(super) fn sub(&self, other: &KyberPoly) -> KyberPoly {
143        let mut result = KyberPoly::zero();
144        let mut i = 0;
145        while i < KYBER_N {
146            result.coeffs[i] =
147                kyber_barrett_reduce(self.coeffs[i] as i32 - other.coeffs[i] as i32) as i16;
148            i += 1;
149        }
150        result
151    }
152
153    /// Sample polynomial with small coefficients from CBD (centered binomial
154    /// distribution)
155    pub(super) fn sample_cbd(seed: &[u8], eta: u32) -> KyberPoly {
156        let mut poly = KyberPoly::zero();
157        let expanded = expand_seed(seed, KYBER_N * eta as usize / 4);
158
159        let mut i = 0;
160        let mut byte_idx = 0;
161        while i < KYBER_N && byte_idx + (eta as usize) <= expanded.len() {
162            let mut a: i16 = 0;
163            let mut b: i16 = 0;
164            let mut j = 0u32;
165            while j < eta {
166                if byte_idx < expanded.len() {
167                    a += ((expanded[byte_idx] >> (j & 7)) & 1) as i16;
168                    b += ((expanded[byte_idx] >> ((j + eta) & 7)) & 1) as i16;
169                }
170                j += 1;
171            }
172            poly.coeffs[i] = kyber_barrett_reduce((a - b) as i32) as i16;
173            byte_idx += 1;
174            i += 1;
175        }
176        poly
177    }
178
179    /// Sample a uniform polynomial from seed (for matrix A)
180    pub(super) fn sample_uniform(seed: &[u8]) -> KyberPoly {
181        let mut poly = KyberPoly::zero();
182        let expanded = expand_seed(seed, KYBER_N * 4);
183
184        let mut i = 0;
185        let mut j = 0;
186        while i < KYBER_N && j + 2 < expanded.len() {
187            let d = (expanded[j] as u16) | ((expanded[j + 1] as u16) << 8);
188            let d = (d & 0x0fff) as i32; // 12 bits
189            if d < KYBER_Q as i32 {
190                poly.coeffs[i] = d as i16;
191                i += 1;
192            }
193            j += 2;
194        }
195        poly
196    }
197
198    /// Compress coefficients to d bits
199    pub(super) fn compress(&self, d: usize) -> Vec<u8> {
200        let mut result = Vec::new();
201        let q = KYBER_Q as u64;
202        let mut bits_buf: u32 = 0;
203        let mut bits_count: u32 = 0;
204
205        let mut i = 0;
206        while i < KYBER_N {
207            let mut c = self.coeffs[i] as i32;
208            if c < 0 {
209                c += KYBER_Q as i32;
210            }
211            // Compress: round(2^d / q * c) mod 2^d
212            let compressed = (((c as u64) << d as u64).wrapping_add(q / 2) / q) & ((1u64 << d) - 1);
213            bits_buf |= (compressed as u32) << bits_count;
214            bits_count += d as u32;
215            while bits_count >= 8 {
216                result.push(bits_buf as u8);
217                bits_buf >>= 8;
218                bits_count -= 8;
219            }
220            i += 1;
221        }
222        if bits_count > 0 {
223            result.push(bits_buf as u8);
224        }
225        result
226    }
227
228    /// Decompress from d bits
229    pub(super) fn decompress(data: &[u8], d: usize) -> KyberPoly {
230        let mut poly = KyberPoly::zero();
231        let q = KYBER_Q as u64;
232        let mask = (1u32 << d) - 1;
233        let mut bits_buf: u32 = 0;
234        let mut bits_count: u32 = 0;
235        let mut byte_idx = 0;
236
237        let mut i = 0;
238        while i < KYBER_N {
239            while bits_count < d as u32 && byte_idx < data.len() {
240                bits_buf |= (data[byte_idx] as u32) << bits_count;
241                bits_count += 8;
242                byte_idx += 1;
243            }
244            let val = bits_buf & mask;
245            bits_buf >>= d;
246            bits_count -= d as u32;
247            // Decompress: round(q / 2^d * val)
248            poly.coeffs[i] = ((val as u64 * q + (1u64 << (d - 1))) >> d) as i16;
249            i += 1;
250        }
251        poly
252    }
253
254    /// Encode polynomial coefficients to bytes (12 bits each)
255    pub(super) fn to_bytes(&self) -> Vec<u8> {
256        let mut result = Vec::with_capacity(KYBER_N * 12 / 8);
257        let mut i = 0;
258        while i < KYBER_N {
259            let mut c0 = self.coeffs[i] as i32;
260            if c0 < 0 {
261                c0 += KYBER_Q as i32;
262            }
263            let c0 = c0 as u16;
264
265            if i + 1 < KYBER_N {
266                let mut c1 = self.coeffs[i + 1] as i32;
267                if c1 < 0 {
268                    c1 += KYBER_Q as i32;
269                }
270                let c1 = c1 as u16;
271
272                result.push(c0 as u8);
273                result.push(((c0 >> 8) | (c1 << 4)) as u8);
274                result.push((c1 >> 4) as u8);
275            } else {
276                result.push(c0 as u8);
277                result.push((c0 >> 8) as u8);
278            }
279            i += 2;
280        }
281        result
282    }
283
284    /// Decode polynomial from bytes (12 bits per coefficient)
285    pub(super) fn from_bytes(data: &[u8]) -> KyberPoly {
286        let mut poly = KyberPoly::zero();
287        let mut i = 0;
288        let mut j = 0;
289        while i < KYBER_N && j + 2 < data.len() {
290            poly.coeffs[i] = (data[j] as i16) | ((data[j + 1] as i16 & 0x0f) << 8);
291            if i + 1 < KYBER_N {
292                poly.coeffs[i + 1] = (data[j + 1] as i16 >> 4) | ((data[j + 2] as i16) << 4);
293            }
294            i += 2;
295            j += 3;
296        }
297        poly
298    }
299}
300
301// ============================================================================
302// ML-KEM (Kyber) Public Types
303// ============================================================================
304
305/// Get Kyber parameters
306fn kyber_params(level: KyberLevel) -> (usize, u32, u32, usize, usize) {
307    // Returns (k, eta1, eta2, du, dv)
308    match level {
309        KyberLevel::Kyber512 => (2, 3, 2, 10, 4),
310        KyberLevel::Kyber768 => (3, 2, 2, 10, 4),
311        KyberLevel::Kyber1024 => (4, 2, 2, 11, 5),
312    }
313}
314
315/// ML-KEM (Kyber) secret key
316pub(crate) struct KyberSecretKey {
317    level: KyberLevel,
318    secret: Vec<u8>,
319}
320
321/// ML-KEM (Kyber) public key
322pub(crate) struct KyberPublicKey {
323    level: KyberLevel,
324    public: Vec<u8>,
325}
326
327/// ML-KEM (Kyber) ciphertext
328pub(crate) struct KyberCiphertext {
329    bytes: Vec<u8>,
330}
331
332/// ML-KEM (Kyber) shared secret
333pub(crate) struct KyberSharedSecret {
334    bytes: [u8; 32],
335}
336
337impl KyberSecretKey {
338    /// Generate new key pair using lattice-based key generation
339    pub(crate) fn generate(level: KyberLevel) -> CryptoResult<Self> {
340        use crate::crypto::{hash::sha256, random::get_random};
341
342        let (k, eta1, _eta2, _du, _dv) = kyber_params(level);
343        let rng = get_random();
344
345        // Generate random seed
346        let mut seed = [0u8; 32];
347        rng.fill_bytes(&mut seed)?;
348
349        // Expand seed for deterministic generation
350        let rho_sigma = {
351            let mut input = Vec::from(seed.as_slice());
352            input.push(k as u8);
353            let hash = crate::crypto::hash::sha512(&input);
354            *hash.as_bytes()
355        };
356        let rho = &rho_sigma[..32]; // For matrix A
357        let sigma = &rho_sigma[32..64]; // For secret/noise
358
359        // Generate matrix A (k x k polynomials in NTT domain)
360        let mut a_hat = Vec::new();
361        let mut i = 0;
362        while i < k * k {
363            let mut a_seed = Vec::new();
364            a_seed.extend_from_slice(rho);
365            a_seed.push((i % k) as u8);
366            a_seed.push((i / k) as u8);
367            let hash = sha256(&a_seed);
368            let mut poly = KyberPoly::sample_uniform(hash.as_bytes());
369            poly.ntt();
370            a_hat.push(poly);
371            i += 1;
372        }
373
374        // Generate secret vector s (k polynomials with small coefficients)
375        let mut s = Vec::new();
376        i = 0;
377        while i < k {
378            let mut s_seed = Vec::new();
379            s_seed.extend_from_slice(sigma);
380            s_seed.push(i as u8);
381            let hash = sha256(&s_seed);
382            let poly = KyberPoly::sample_cbd(hash.as_bytes(), eta1);
383            s.push(poly);
384            i += 1;
385        }
386
387        // Generate error vector e (k polynomials)
388        let mut e = Vec::new();
389        i = 0;
390        while i < k {
391            let mut e_seed = Vec::new();
392            e_seed.extend_from_slice(sigma);
393            e_seed.push((k + i) as u8);
394            let hash = sha256(&e_seed);
395            let poly = KyberPoly::sample_cbd(hash.as_bytes(), eta1);
396            e.push(poly);
397            i += 1;
398        }
399
400        // Compute NTT of s
401        let mut s_hat = Vec::new();
402        for si in &s {
403            let mut si_ntt = si.clone();
404            si_ntt.ntt();
405            s_hat.push(si_ntt);
406        }
407
408        // Compute NTT of e
409        let mut e_hat = Vec::new();
410        for ei in &e {
411            let mut ei_ntt = ei.clone();
412            ei_ntt.ntt();
413            e_hat.push(ei_ntt);
414        }
415
416        // Compute t_hat = A_hat * s_hat + e_hat
417        let mut t_hat = Vec::new();
418        i = 0;
419        while i < k {
420            let mut ti = KyberPoly::zero();
421            let mut j = 0;
422            while j < k {
423                let product = a_hat[i * k + j].pointwise_mul(&s_hat[j]);
424                ti = ti.add(&product);
425                j += 1;
426            }
427            ti = ti.add(&e_hat[i]);
428            t_hat.push(ti);
429            i += 1;
430        }
431
432        // Serialize secret key: seed || s || pk
433        let mut secret_data = Vec::new();
434        secret_data.extend_from_slice(&seed);
435        // Store s in NTT domain
436        for si in &s_hat {
437            secret_data.extend_from_slice(&si.to_bytes());
438        }
439        // Store hash of public key for decapsulation
440        let mut pk_bytes = Vec::new();
441        pk_bytes.extend_from_slice(rho);
442        for ti in &t_hat {
443            pk_bytes.extend_from_slice(&ti.to_bytes());
444        }
445        let pk_hash = sha256(&pk_bytes);
446        secret_data.extend_from_slice(pk_hash.as_bytes());
447        // Store rho for matrix reconstruction
448        secret_data.extend_from_slice(rho);
449        // Store t_hat for public key derivation
450        for ti in &t_hat {
451            secret_data.extend_from_slice(&ti.to_bytes());
452        }
453
454        Ok(Self {
455            level,
456            secret: secret_data,
457        })
458    }
459
460    /// Get corresponding public key
461    pub(crate) fn public_key(&self) -> KyberPublicKey {
462        let (k, _, _, _, _) = kyber_params(self.level);
463
464        let public_size = match self.level {
465            KyberLevel::Kyber512 => 800,
466            KyberLevel::Kyber768 => 1184,
467            KyberLevel::Kyber1024 => 1568,
468        };
469
470        // Public key = rho || t_hat
471        // rho is at offset: 32 (seed) + k * poly_bytes_full (s_hat) + 32 (pk_hash)
472        let s_hat_bytes = k * KYBER_N * 12 / 8;
473        let rho_offset = 32 + s_hat_bytes + 32;
474
475        let mut public = Vec::with_capacity(public_size);
476
477        // Extract rho and t_hat from secret key
478        if rho_offset + 32 <= self.secret.len() {
479            public.extend_from_slice(&self.secret[rho_offset..rho_offset + 32]);
480        } else {
481            public.extend_from_slice(&[0u8; 32]);
482        }
483
484        // t_hat follows rho
485        let t_start = rho_offset + 32;
486        let mut i = 0;
487        while i < k {
488            let offset = t_start + i * (KYBER_N * 12 / 8);
489            if offset + KYBER_N * 12 / 8 <= self.secret.len() {
490                public.extend_from_slice(&self.secret[offset..offset + KYBER_N * 12 / 8]);
491            }
492            i += 1;
493        }
494
495        // Pad to expected size
496        while public.len() < public_size {
497            public.push(0);
498        }
499        public.truncate(public_size);
500
501        KyberPublicKey {
502            level: self.level,
503            public,
504        }
505    }
506
507    /// Decapsulate to get shared secret
508    pub(crate) fn decapsulate(
509        &self,
510        ciphertext: &KyberCiphertext,
511    ) -> CryptoResult<KyberSharedSecret> {
512        use crate::crypto::hash::sha256;
513
514        let (k, _, _, du, dv) = kyber_params(self.level);
515
516        // Extract s_hat from secret key
517        let poly_bytes = KYBER_N * 12 / 8;
518        let mut s_hat = Vec::new();
519        let mut i = 0;
520        while i < k {
521            let offset = 32 + i * poly_bytes;
522            if offset + poly_bytes <= self.secret.len() {
523                s_hat.push(KyberPoly::from_bytes(
524                    &self.secret[offset..offset + poly_bytes],
525                ));
526            } else {
527                s_hat.push(KyberPoly::zero());
528            }
529            i += 1;
530        }
531
532        // Parse ciphertext: u (k compressed polynomials) || v (1 compressed polynomial)
533        let u_bytes_per_poly = KYBER_N * du / 8;
534        let v_bytes = KYBER_N * dv / 8;
535        let mut u = Vec::new();
536        i = 0;
537        while i < k {
538            let offset = i * u_bytes_per_poly;
539            if offset + u_bytes_per_poly <= ciphertext.bytes.len() {
540                let mut ui =
541                    KyberPoly::decompress(&ciphertext.bytes[offset..offset + u_bytes_per_poly], du);
542                ui.ntt();
543                u.push(ui);
544            } else {
545                u.push(KyberPoly::zero());
546            }
547            i += 1;
548        }
549
550        let v_offset = k * u_bytes_per_poly;
551        let v = if v_offset + v_bytes <= ciphertext.bytes.len() {
552            KyberPoly::decompress(&ciphertext.bytes[v_offset..v_offset + v_bytes], dv)
553        } else {
554            KyberPoly::zero()
555        };
556
557        // Compute m' = v - s^T * u
558        let mut su = KyberPoly::zero();
559        i = 0;
560        while i < k {
561            let product = s_hat[i].pointwise_mul(&u[i]);
562            su = su.add(&product);
563            i += 1;
564        }
565        su.inv_ntt();
566
567        let m_prime = v.sub(&su);
568
569        // Decode message from polynomial (each coefficient encodes 1 bit)
570        let mut msg = [0u8; 32];
571        let mut bi = 0;
572        while bi < 256 && bi / 8 < 32 {
573            let mut coeff = m_prime.coeffs[bi] as i32;
574            if coeff < 0 {
575                coeff += KYBER_Q as i32;
576            }
577            // Round to nearest: if coeff > q/2, bit = 1; else bit = 0
578            let bit = if coeff > KYBER_Q as i32 / 2 { 1u8 } else { 0u8 };
579            msg[bi / 8] |= bit << (bi % 8);
580            bi += 1;
581        }
582
583        // Derive shared secret from message using hash
584        let mut hash_input = Vec::new();
585        hash_input.extend_from_slice(&msg);
586        hash_input.extend_from_slice(&ciphertext.bytes);
587        let shared = sha256(&hash_input);
588
589        Ok(KyberSharedSecret {
590            bytes: *shared.as_bytes(),
591        })
592    }
593}
594
595impl KyberPublicKey {
596    /// Encapsulate to generate shared secret and ciphertext
597    pub(crate) fn encapsulate(&self) -> CryptoResult<(KyberCiphertext, KyberSharedSecret)> {
598        use crate::crypto::{hash::sha256, random::get_random};
599
600        let (k, _eta1, eta2, du, dv) = kyber_params(self.level);
601        let rng = get_random();
602
603        // Generate random message
604        let mut msg = [0u8; 32];
605        rng.fill_bytes(&mut msg)?;
606
607        // Derive coins from message and public key hash
608        let pk_hash = sha256(&self.public);
609        let mut coin_input = Vec::new();
610        coin_input.extend_from_slice(&msg);
611        coin_input.extend_from_slice(pk_hash.as_bytes());
612        let coins = sha256(&coin_input);
613
614        // Extract rho and t_hat from public key
615        let rho = if self.public.len() >= 32 {
616            &self.public[..32]
617        } else {
618            &[0u8; 32]
619        };
620        let poly_bytes = KYBER_N * 12 / 8;
621
622        let mut t_hat = Vec::new();
623        let mut i = 0;
624        while i < k {
625            let offset = 32 + i * poly_bytes;
626            if offset + poly_bytes <= self.public.len() {
627                t_hat.push(KyberPoly::from_bytes(
628                    &self.public[offset..offset + poly_bytes],
629                ));
630            } else {
631                t_hat.push(KyberPoly::zero());
632            }
633            i += 1;
634        }
635
636        // Reconstruct matrix A_hat
637        let mut a_hat = Vec::new();
638        i = 0;
639        while i < k * k {
640            let mut a_seed = Vec::new();
641            a_seed.extend_from_slice(rho);
642            a_seed.push((i % k) as u8);
643            a_seed.push((i / k) as u8);
644            let hash = sha256(&a_seed);
645            let mut poly = KyberPoly::sample_uniform(hash.as_bytes());
646            poly.ntt();
647            a_hat.push(poly);
648            i += 1;
649        }
650
651        // Generate random vectors r, e1, e2
652        let mut r = Vec::new();
653        i = 0;
654        while i < k {
655            let mut r_seed = Vec::new();
656            r_seed.extend_from_slice(coins.as_bytes());
657            r_seed.push(i as u8);
658            let hash = sha256(&r_seed);
659            let mut poly = KyberPoly::sample_cbd(hash.as_bytes(), eta2);
660            poly.ntt();
661            r.push(poly);
662            i += 1;
663        }
664
665        let mut e1 = Vec::new();
666        i = 0;
667        while i < k {
668            let mut e1_seed = Vec::new();
669            e1_seed.extend_from_slice(coins.as_bytes());
670            e1_seed.push((k + i) as u8);
671            let hash = sha256(&e1_seed);
672            e1.push(KyberPoly::sample_cbd(hash.as_bytes(), eta2));
673            i += 1;
674        }
675
676        let mut e2_seed = Vec::new();
677        e2_seed.extend_from_slice(coins.as_bytes());
678        e2_seed.push((2 * k) as u8);
679        let e2_hash = sha256(&e2_seed);
680        let e2 = KyberPoly::sample_cbd(e2_hash.as_bytes(), eta2);
681
682        // Compute u = A^T * r + e1
683        let mut u = Vec::new();
684        i = 0;
685        while i < k {
686            let mut ui = KyberPoly::zero();
687            let mut j = 0;
688            while j < k {
689                // A^T: swap indices
690                let product = a_hat[j * k + i].pointwise_mul(&r[j]);
691                ui = ui.add(&product);
692                j += 1;
693            }
694            ui.inv_ntt();
695            ui = ui.add(&e1[i]);
696            u.push(ui);
697            i += 1;
698        }
699
700        // Compute v = t^T * r + e2 + encode(msg)
701        let mut v = KyberPoly::zero();
702        i = 0;
703        while i < k {
704            let product = t_hat[i].pointwise_mul(&r[i]);
705            v = v.add(&product);
706            i += 1;
707        }
708        v.inv_ntt();
709        v = v.add(&e2);
710
711        // Encode message into polynomial (each bit becomes q/2 or 0)
712        let mut msg_poly = KyberPoly::zero();
713        let mut bi = 0;
714        while bi < 256 && bi / 8 < 32 {
715            let bit = (msg[bi / 8] >> (bi % 8)) & 1;
716            if bit == 1 {
717                msg_poly.coeffs[bi] = KYBER_Q.div_ceil(2) as i16;
718            }
719            bi += 1;
720        }
721        v = v.add(&msg_poly);
722
723        // Build ciphertext: compress(u) || compress(v)
724        let mut ct_bytes = Vec::new();
725        for ui in &u {
726            ct_bytes.extend_from_slice(&ui.compress(du));
727        }
728        ct_bytes.extend_from_slice(&v.compress(dv));
729
730        // Derive shared secret
731        let mut hash_input = Vec::new();
732        hash_input.extend_from_slice(&msg);
733        hash_input.extend_from_slice(&ct_bytes);
734        let shared = sha256(&hash_input);
735
736        Ok((
737            KyberCiphertext { bytes: ct_bytes },
738            KyberSharedSecret {
739                bytes: *shared.as_bytes(),
740            },
741        ))
742    }
743}
744
745impl KyberSharedSecret {
746    /// Get shared secret bytes
747    pub(crate) fn as_bytes(&self) -> &[u8; 32] {
748        &self.bytes
749    }
750}