veridian_kernel/crypto/post_quantum/
kyber.rs1use alloc::vec::Vec;
7
8use super::{
9 bit_reverse_7, expand_seed, kyber_barrett_reduce, mod_pow, KyberLevel, KYBER_N, KYBER_Q,
10 KYBER_ZETA,
11};
12use crate::crypto::CryptoResult;
13
14const KYBER_MONT_R: u32 = 2285; const KYBER_BARRETT_V: u32 = 20159; fn kyber_ntt_zetas() -> [u16; 128] {
27 let mut zetas = [0u16; 128];
28 let mut i = 0;
29 while i < 128 {
30 let rev = bit_reverse_7(i as u8) as u32;
31 zetas[i] = mod_pow(KYBER_ZETA, rev, KYBER_Q) as u16;
32 i += 1;
33 }
34 zetas
35}
36
37#[derive(Clone)]
43pub(super) struct KyberPoly {
44 coeffs: [i16; KYBER_N],
45}
46
47impl KyberPoly {
48 pub(super) fn zero() -> Self {
49 KyberPoly {
50 coeffs: [0i16; KYBER_N],
51 }
52 }
53
54 pub(super) fn ntt(&mut self) {
57 let zetas = kyber_ntt_zetas();
58 let mut k = 1usize;
59 let mut len = 128;
60 while len >= 2 {
61 let mut start = 0;
62 while start < KYBER_N {
63 let zeta = zetas[k] as i32;
64 k += 1;
65 let mut j = start;
66 while j < start + len {
67 let t = ((zeta as i64 * self.coeffs[j + len] as i64) % KYBER_Q as i64) as i32;
68 let t = kyber_barrett_reduce(t) as i16;
69 self.coeffs[j + len] =
70 kyber_barrett_reduce(self.coeffs[j] as i32 - t as i32) as i16;
71 self.coeffs[j] = kyber_barrett_reduce(self.coeffs[j] as i32 + t as i32) as i16;
72 j += 1;
73 }
74 start += 2 * len;
75 }
76 len >>= 1;
77 }
78 }
79
80 pub(super) fn inv_ntt(&mut self) {
82 let zetas = kyber_ntt_zetas();
83 let mut k = 127usize;
84 let mut len = 2;
85 while len <= 128 {
86 let mut start = 0;
87 while start < KYBER_N {
88 let zeta = zetas[k] as i32;
89 k = k.wrapping_sub(1);
90 let mut j = start;
91 while j < start + len {
92 let t = self.coeffs[j] as i32;
93 self.coeffs[j] = kyber_barrett_reduce(t + self.coeffs[j + len] as i32) as i16;
94 let diff = t - self.coeffs[j + len] as i32;
95 self.coeffs[j + len] =
96 kyber_barrett_reduce(((zeta as i64 * diff as i64) % KYBER_Q as i64) as i32)
97 as i16;
98 j += 1;
99 }
100 start += 2 * len;
101 }
102 len <<= 1;
103 }
104
105 let n_inv = mod_pow(256, KYBER_Q - 2, KYBER_Q) as i32;
107 let mut i = 0;
108 while i < KYBER_N {
109 self.coeffs[i] = kyber_barrett_reduce(
110 ((self.coeffs[i] as i64 * n_inv as i64) % KYBER_Q as i64) as i32,
111 ) as i16;
112 i += 1;
113 }
114 }
115
116 pub(super) fn pointwise_mul(&self, other: &KyberPoly) -> KyberPoly {
118 let mut result = KyberPoly::zero();
119 let mut i = 0;
120 while i < KYBER_N {
121 result.coeffs[i] = kyber_barrett_reduce(
122 ((self.coeffs[i] as i64 * other.coeffs[i] as i64) % KYBER_Q as i64) as i32,
123 ) as i16;
124 i += 1;
125 }
126 result
127 }
128
129 pub(super) fn add(&self, other: &KyberPoly) -> KyberPoly {
131 let mut result = KyberPoly::zero();
132 let mut i = 0;
133 while i < KYBER_N {
134 result.coeffs[i] =
135 kyber_barrett_reduce(self.coeffs[i] as i32 + other.coeffs[i] as i32) as i16;
136 i += 1;
137 }
138 result
139 }
140
141 pub(super) fn sub(&self, other: &KyberPoly) -> KyberPoly {
143 let mut result = KyberPoly::zero();
144 let mut i = 0;
145 while i < KYBER_N {
146 result.coeffs[i] =
147 kyber_barrett_reduce(self.coeffs[i] as i32 - other.coeffs[i] as i32) as i16;
148 i += 1;
149 }
150 result
151 }
152
153 pub(super) fn sample_cbd(seed: &[u8], eta: u32) -> KyberPoly {
156 let mut poly = KyberPoly::zero();
157 let expanded = expand_seed(seed, KYBER_N * eta as usize / 4);
158
159 let mut i = 0;
160 let mut byte_idx = 0;
161 while i < KYBER_N && byte_idx + (eta as usize) <= expanded.len() {
162 let mut a: i16 = 0;
163 let mut b: i16 = 0;
164 let mut j = 0u32;
165 while j < eta {
166 if byte_idx < expanded.len() {
167 a += ((expanded[byte_idx] >> (j & 7)) & 1) as i16;
168 b += ((expanded[byte_idx] >> ((j + eta) & 7)) & 1) as i16;
169 }
170 j += 1;
171 }
172 poly.coeffs[i] = kyber_barrett_reduce((a - b) as i32) as i16;
173 byte_idx += 1;
174 i += 1;
175 }
176 poly
177 }
178
179 pub(super) fn sample_uniform(seed: &[u8]) -> KyberPoly {
181 let mut poly = KyberPoly::zero();
182 let expanded = expand_seed(seed, KYBER_N * 4);
183
184 let mut i = 0;
185 let mut j = 0;
186 while i < KYBER_N && j + 2 < expanded.len() {
187 let d = (expanded[j] as u16) | ((expanded[j + 1] as u16) << 8);
188 let d = (d & 0x0fff) as i32; if d < KYBER_Q as i32 {
190 poly.coeffs[i] = d as i16;
191 i += 1;
192 }
193 j += 2;
194 }
195 poly
196 }
197
198 pub(super) fn compress(&self, d: usize) -> Vec<u8> {
200 let mut result = Vec::new();
201 let q = KYBER_Q as u64;
202 let mut bits_buf: u32 = 0;
203 let mut bits_count: u32 = 0;
204
205 let mut i = 0;
206 while i < KYBER_N {
207 let mut c = self.coeffs[i] as i32;
208 if c < 0 {
209 c += KYBER_Q as i32;
210 }
211 let compressed = (((c as u64) << d as u64).wrapping_add(q / 2) / q) & ((1u64 << d) - 1);
213 bits_buf |= (compressed as u32) << bits_count;
214 bits_count += d as u32;
215 while bits_count >= 8 {
216 result.push(bits_buf as u8);
217 bits_buf >>= 8;
218 bits_count -= 8;
219 }
220 i += 1;
221 }
222 if bits_count > 0 {
223 result.push(bits_buf as u8);
224 }
225 result
226 }
227
228 pub(super) fn decompress(data: &[u8], d: usize) -> KyberPoly {
230 let mut poly = KyberPoly::zero();
231 let q = KYBER_Q as u64;
232 let mask = (1u32 << d) - 1;
233 let mut bits_buf: u32 = 0;
234 let mut bits_count: u32 = 0;
235 let mut byte_idx = 0;
236
237 let mut i = 0;
238 while i < KYBER_N {
239 while bits_count < d as u32 && byte_idx < data.len() {
240 bits_buf |= (data[byte_idx] as u32) << bits_count;
241 bits_count += 8;
242 byte_idx += 1;
243 }
244 let val = bits_buf & mask;
245 bits_buf >>= d;
246 bits_count -= d as u32;
247 poly.coeffs[i] = ((val as u64 * q + (1u64 << (d - 1))) >> d) as i16;
249 i += 1;
250 }
251 poly
252 }
253
254 pub(super) fn to_bytes(&self) -> Vec<u8> {
256 let mut result = Vec::with_capacity(KYBER_N * 12 / 8);
257 let mut i = 0;
258 while i < KYBER_N {
259 let mut c0 = self.coeffs[i] as i32;
260 if c0 < 0 {
261 c0 += KYBER_Q as i32;
262 }
263 let c0 = c0 as u16;
264
265 if i + 1 < KYBER_N {
266 let mut c1 = self.coeffs[i + 1] as i32;
267 if c1 < 0 {
268 c1 += KYBER_Q as i32;
269 }
270 let c1 = c1 as u16;
271
272 result.push(c0 as u8);
273 result.push(((c0 >> 8) | (c1 << 4)) as u8);
274 result.push((c1 >> 4) as u8);
275 } else {
276 result.push(c0 as u8);
277 result.push((c0 >> 8) as u8);
278 }
279 i += 2;
280 }
281 result
282 }
283
284 pub(super) fn from_bytes(data: &[u8]) -> KyberPoly {
286 let mut poly = KyberPoly::zero();
287 let mut i = 0;
288 let mut j = 0;
289 while i < KYBER_N && j + 2 < data.len() {
290 poly.coeffs[i] = (data[j] as i16) | ((data[j + 1] as i16 & 0x0f) << 8);
291 if i + 1 < KYBER_N {
292 poly.coeffs[i + 1] = (data[j + 1] as i16 >> 4) | ((data[j + 2] as i16) << 4);
293 }
294 i += 2;
295 j += 3;
296 }
297 poly
298 }
299}
300
301fn kyber_params(level: KyberLevel) -> (usize, u32, u32, usize, usize) {
307 match level {
309 KyberLevel::Kyber512 => (2, 3, 2, 10, 4),
310 KyberLevel::Kyber768 => (3, 2, 2, 10, 4),
311 KyberLevel::Kyber1024 => (4, 2, 2, 11, 5),
312 }
313}
314
315pub(crate) struct KyberSecretKey {
317 level: KyberLevel,
318 secret: Vec<u8>,
319}
320
321pub(crate) struct KyberPublicKey {
323 level: KyberLevel,
324 public: Vec<u8>,
325}
326
327pub(crate) struct KyberCiphertext {
329 bytes: Vec<u8>,
330}
331
332pub(crate) struct KyberSharedSecret {
334 bytes: [u8; 32],
335}
336
337impl KyberSecretKey {
338 pub(crate) fn generate(level: KyberLevel) -> CryptoResult<Self> {
340 use crate::crypto::{hash::sha256, random::get_random};
341
342 let (k, eta1, _eta2, _du, _dv) = kyber_params(level);
343 let rng = get_random();
344
345 let mut seed = [0u8; 32];
347 rng.fill_bytes(&mut seed)?;
348
349 let rho_sigma = {
351 let mut input = Vec::from(seed.as_slice());
352 input.push(k as u8);
353 let hash = crate::crypto::hash::sha512(&input);
354 *hash.as_bytes()
355 };
356 let rho = &rho_sigma[..32]; let sigma = &rho_sigma[32..64]; let mut a_hat = Vec::new();
361 let mut i = 0;
362 while i < k * k {
363 let mut a_seed = Vec::new();
364 a_seed.extend_from_slice(rho);
365 a_seed.push((i % k) as u8);
366 a_seed.push((i / k) as u8);
367 let hash = sha256(&a_seed);
368 let mut poly = KyberPoly::sample_uniform(hash.as_bytes());
369 poly.ntt();
370 a_hat.push(poly);
371 i += 1;
372 }
373
374 let mut s = Vec::new();
376 i = 0;
377 while i < k {
378 let mut s_seed = Vec::new();
379 s_seed.extend_from_slice(sigma);
380 s_seed.push(i as u8);
381 let hash = sha256(&s_seed);
382 let poly = KyberPoly::sample_cbd(hash.as_bytes(), eta1);
383 s.push(poly);
384 i += 1;
385 }
386
387 let mut e = Vec::new();
389 i = 0;
390 while i < k {
391 let mut e_seed = Vec::new();
392 e_seed.extend_from_slice(sigma);
393 e_seed.push((k + i) as u8);
394 let hash = sha256(&e_seed);
395 let poly = KyberPoly::sample_cbd(hash.as_bytes(), eta1);
396 e.push(poly);
397 i += 1;
398 }
399
400 let mut s_hat = Vec::new();
402 for si in &s {
403 let mut si_ntt = si.clone();
404 si_ntt.ntt();
405 s_hat.push(si_ntt);
406 }
407
408 let mut e_hat = Vec::new();
410 for ei in &e {
411 let mut ei_ntt = ei.clone();
412 ei_ntt.ntt();
413 e_hat.push(ei_ntt);
414 }
415
416 let mut t_hat = Vec::new();
418 i = 0;
419 while i < k {
420 let mut ti = KyberPoly::zero();
421 let mut j = 0;
422 while j < k {
423 let product = a_hat[i * k + j].pointwise_mul(&s_hat[j]);
424 ti = ti.add(&product);
425 j += 1;
426 }
427 ti = ti.add(&e_hat[i]);
428 t_hat.push(ti);
429 i += 1;
430 }
431
432 let mut secret_data = Vec::new();
434 secret_data.extend_from_slice(&seed);
435 for si in &s_hat {
437 secret_data.extend_from_slice(&si.to_bytes());
438 }
439 let mut pk_bytes = Vec::new();
441 pk_bytes.extend_from_slice(rho);
442 for ti in &t_hat {
443 pk_bytes.extend_from_slice(&ti.to_bytes());
444 }
445 let pk_hash = sha256(&pk_bytes);
446 secret_data.extend_from_slice(pk_hash.as_bytes());
447 secret_data.extend_from_slice(rho);
449 for ti in &t_hat {
451 secret_data.extend_from_slice(&ti.to_bytes());
452 }
453
454 Ok(Self {
455 level,
456 secret: secret_data,
457 })
458 }
459
460 pub(crate) fn public_key(&self) -> KyberPublicKey {
462 let (k, _, _, _, _) = kyber_params(self.level);
463
464 let public_size = match self.level {
465 KyberLevel::Kyber512 => 800,
466 KyberLevel::Kyber768 => 1184,
467 KyberLevel::Kyber1024 => 1568,
468 };
469
470 let s_hat_bytes = k * KYBER_N * 12 / 8;
473 let rho_offset = 32 + s_hat_bytes + 32;
474
475 let mut public = Vec::with_capacity(public_size);
476
477 if rho_offset + 32 <= self.secret.len() {
479 public.extend_from_slice(&self.secret[rho_offset..rho_offset + 32]);
480 } else {
481 public.extend_from_slice(&[0u8; 32]);
482 }
483
484 let t_start = rho_offset + 32;
486 let mut i = 0;
487 while i < k {
488 let offset = t_start + i * (KYBER_N * 12 / 8);
489 if offset + KYBER_N * 12 / 8 <= self.secret.len() {
490 public.extend_from_slice(&self.secret[offset..offset + KYBER_N * 12 / 8]);
491 }
492 i += 1;
493 }
494
495 while public.len() < public_size {
497 public.push(0);
498 }
499 public.truncate(public_size);
500
501 KyberPublicKey {
502 level: self.level,
503 public,
504 }
505 }
506
507 pub(crate) fn decapsulate(
509 &self,
510 ciphertext: &KyberCiphertext,
511 ) -> CryptoResult<KyberSharedSecret> {
512 use crate::crypto::hash::sha256;
513
514 let (k, _, _, du, dv) = kyber_params(self.level);
515
516 let poly_bytes = KYBER_N * 12 / 8;
518 let mut s_hat = Vec::new();
519 let mut i = 0;
520 while i < k {
521 let offset = 32 + i * poly_bytes;
522 if offset + poly_bytes <= self.secret.len() {
523 s_hat.push(KyberPoly::from_bytes(
524 &self.secret[offset..offset + poly_bytes],
525 ));
526 } else {
527 s_hat.push(KyberPoly::zero());
528 }
529 i += 1;
530 }
531
532 let u_bytes_per_poly = KYBER_N * du / 8;
534 let v_bytes = KYBER_N * dv / 8;
535 let mut u = Vec::new();
536 i = 0;
537 while i < k {
538 let offset = i * u_bytes_per_poly;
539 if offset + u_bytes_per_poly <= ciphertext.bytes.len() {
540 let mut ui =
541 KyberPoly::decompress(&ciphertext.bytes[offset..offset + u_bytes_per_poly], du);
542 ui.ntt();
543 u.push(ui);
544 } else {
545 u.push(KyberPoly::zero());
546 }
547 i += 1;
548 }
549
550 let v_offset = k * u_bytes_per_poly;
551 let v = if v_offset + v_bytes <= ciphertext.bytes.len() {
552 KyberPoly::decompress(&ciphertext.bytes[v_offset..v_offset + v_bytes], dv)
553 } else {
554 KyberPoly::zero()
555 };
556
557 let mut su = KyberPoly::zero();
559 i = 0;
560 while i < k {
561 let product = s_hat[i].pointwise_mul(&u[i]);
562 su = su.add(&product);
563 i += 1;
564 }
565 su.inv_ntt();
566
567 let m_prime = v.sub(&su);
568
569 let mut msg = [0u8; 32];
571 let mut bi = 0;
572 while bi < 256 && bi / 8 < 32 {
573 let mut coeff = m_prime.coeffs[bi] as i32;
574 if coeff < 0 {
575 coeff += KYBER_Q as i32;
576 }
577 let bit = if coeff > KYBER_Q as i32 / 2 { 1u8 } else { 0u8 };
579 msg[bi / 8] |= bit << (bi % 8);
580 bi += 1;
581 }
582
583 let mut hash_input = Vec::new();
585 hash_input.extend_from_slice(&msg);
586 hash_input.extend_from_slice(&ciphertext.bytes);
587 let shared = sha256(&hash_input);
588
589 Ok(KyberSharedSecret {
590 bytes: *shared.as_bytes(),
591 })
592 }
593}
594
595impl KyberPublicKey {
596 pub(crate) fn encapsulate(&self) -> CryptoResult<(KyberCiphertext, KyberSharedSecret)> {
598 use crate::crypto::{hash::sha256, random::get_random};
599
600 let (k, _eta1, eta2, du, dv) = kyber_params(self.level);
601 let rng = get_random();
602
603 let mut msg = [0u8; 32];
605 rng.fill_bytes(&mut msg)?;
606
607 let pk_hash = sha256(&self.public);
609 let mut coin_input = Vec::new();
610 coin_input.extend_from_slice(&msg);
611 coin_input.extend_from_slice(pk_hash.as_bytes());
612 let coins = sha256(&coin_input);
613
614 let rho = if self.public.len() >= 32 {
616 &self.public[..32]
617 } else {
618 &[0u8; 32]
619 };
620 let poly_bytes = KYBER_N * 12 / 8;
621
622 let mut t_hat = Vec::new();
623 let mut i = 0;
624 while i < k {
625 let offset = 32 + i * poly_bytes;
626 if offset + poly_bytes <= self.public.len() {
627 t_hat.push(KyberPoly::from_bytes(
628 &self.public[offset..offset + poly_bytes],
629 ));
630 } else {
631 t_hat.push(KyberPoly::zero());
632 }
633 i += 1;
634 }
635
636 let mut a_hat = Vec::new();
638 i = 0;
639 while i < k * k {
640 let mut a_seed = Vec::new();
641 a_seed.extend_from_slice(rho);
642 a_seed.push((i % k) as u8);
643 a_seed.push((i / k) as u8);
644 let hash = sha256(&a_seed);
645 let mut poly = KyberPoly::sample_uniform(hash.as_bytes());
646 poly.ntt();
647 a_hat.push(poly);
648 i += 1;
649 }
650
651 let mut r = Vec::new();
653 i = 0;
654 while i < k {
655 let mut r_seed = Vec::new();
656 r_seed.extend_from_slice(coins.as_bytes());
657 r_seed.push(i as u8);
658 let hash = sha256(&r_seed);
659 let mut poly = KyberPoly::sample_cbd(hash.as_bytes(), eta2);
660 poly.ntt();
661 r.push(poly);
662 i += 1;
663 }
664
665 let mut e1 = Vec::new();
666 i = 0;
667 while i < k {
668 let mut e1_seed = Vec::new();
669 e1_seed.extend_from_slice(coins.as_bytes());
670 e1_seed.push((k + i) as u8);
671 let hash = sha256(&e1_seed);
672 e1.push(KyberPoly::sample_cbd(hash.as_bytes(), eta2));
673 i += 1;
674 }
675
676 let mut e2_seed = Vec::new();
677 e2_seed.extend_from_slice(coins.as_bytes());
678 e2_seed.push((2 * k) as u8);
679 let e2_hash = sha256(&e2_seed);
680 let e2 = KyberPoly::sample_cbd(e2_hash.as_bytes(), eta2);
681
682 let mut u = Vec::new();
684 i = 0;
685 while i < k {
686 let mut ui = KyberPoly::zero();
687 let mut j = 0;
688 while j < k {
689 let product = a_hat[j * k + i].pointwise_mul(&r[j]);
691 ui = ui.add(&product);
692 j += 1;
693 }
694 ui.inv_ntt();
695 ui = ui.add(&e1[i]);
696 u.push(ui);
697 i += 1;
698 }
699
700 let mut v = KyberPoly::zero();
702 i = 0;
703 while i < k {
704 let product = t_hat[i].pointwise_mul(&r[i]);
705 v = v.add(&product);
706 i += 1;
707 }
708 v.inv_ntt();
709 v = v.add(&e2);
710
711 let mut msg_poly = KyberPoly::zero();
713 let mut bi = 0;
714 while bi < 256 && bi / 8 < 32 {
715 let bit = (msg[bi / 8] >> (bi % 8)) & 1;
716 if bit == 1 {
717 msg_poly.coeffs[bi] = KYBER_Q.div_ceil(2) as i16;
718 }
719 bi += 1;
720 }
721 v = v.add(&msg_poly);
722
723 let mut ct_bytes = Vec::new();
725 for ui in &u {
726 ct_bytes.extend_from_slice(&ui.compress(du));
727 }
728 ct_bytes.extend_from_slice(&v.compress(dv));
729
730 let mut hash_input = Vec::new();
732 hash_input.extend_from_slice(&msg);
733 hash_input.extend_from_slice(&ct_bytes);
734 let shared = sha256(&hash_input);
735
736 Ok((
737 KyberCiphertext { bytes: ct_bytes },
738 KyberSharedSecret {
739 bytes: *shared.as_bytes(),
740 },
741 ))
742 }
743}
744
745impl KyberSharedSecret {
746 pub(crate) fn as_bytes(&self) -> &[u8; 32] {
748 &self.bytes
749 }
750}