veridian_kernel/security/
dilithium.rs1#![allow(dead_code)]
11
12#[cfg(feature = "alloc")]
13extern crate alloc;
14
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18use crate::error::KernelError;
19
/// Length in bytes of the seed `rho` at the start of an encoded public key.
const SEED_SIZE: usize = 32;

/// Length in bytes of the challenge seed `c_tilde` at the start of an encoded signature.
const C_TILDE_SIZE: usize = 32;

/// Number of coefficients per polynomial.
const N: usize = 256;

/// Dilithium3 parameter `k`: number of polynomials in the packed `t1` vector of a public key.
const K: usize = 6;

/// Dilithium3 parameter `l`: number of polynomials in the response vector `z` of a signature.
const L: usize = 5;

/// Dilithium3 modulus `q`.
const DILITHIUM_Q: u32 = 8380417;

/// Coefficient range bound for `z`; each `z` coefficient is packed into 20 bits.
const GAMMA1: u32 = 1 << 19;

/// Dilithium3 parameter `omega`; together with `K` it sizes the trailing hint section of a
/// signature.
const OMEGA: usize = 55;

/// Size in bytes of an encoded Dilithium3 public key: `rho` followed by `K` polynomials
/// packed at 10 bits per coefficient (1952 bytes).
pub const PUBLIC_KEY_SIZE: usize = SEED_SIZE + K * N * 10 / 8;

/// Size in bytes of an encoded Dilithium3 signature: `c_tilde`, then `L` polynomials of `z`
/// packed at 20 bits per coefficient, then `OMEGA + K` hint bytes (3293 bytes).
pub const SIGNATURE_SIZE: usize = C_TILDE_SIZE + L * N * 20 / 8 + OMEGA + K;
48
49pub struct DilithiumPublicKey {
55 bytes: Vec<u8>,
56}
57
58impl DilithiumPublicKey {
59 pub fn from_bytes(data: &[u8]) -> Result<Self, KernelError> {
61 if data.len() < PUBLIC_KEY_SIZE {
62 return Err(KernelError::InvalidArgument {
63 name: "public_key",
64 value: "too short for Dilithium3",
65 });
66 }
67 let mut bytes = Vec::with_capacity(PUBLIC_KEY_SIZE);
68 bytes.extend_from_slice(&data[..PUBLIC_KEY_SIZE]);
69 Ok(Self { bytes })
70 }
71
72 pub fn rho(&self) -> &[u8] {
74 &self.bytes[..SEED_SIZE]
75 }
76
77 pub fn t1_bytes(&self) -> &[u8] {
79 &self.bytes[SEED_SIZE..]
80 }
81}
82
83pub struct DilithiumSignature {
89 bytes: Vec<u8>,
90}
91
92impl DilithiumSignature {
93 pub fn from_bytes(data: &[u8]) -> Result<Self, KernelError> {
95 if data.len() < SIGNATURE_SIZE {
96 return Err(KernelError::InvalidArgument {
97 name: "signature",
98 value: "too short for Dilithium3",
99 });
100 }
101 let mut bytes = Vec::with_capacity(SIGNATURE_SIZE);
102 bytes.extend_from_slice(&data[..SIGNATURE_SIZE]);
103 Ok(Self { bytes })
104 }
105
106 pub fn c_tilde(&self) -> &[u8] {
108 &self.bytes[..C_TILDE_SIZE]
109 }
110
111 pub fn z_bytes(&self) -> &[u8] {
113 &self.bytes[C_TILDE_SIZE..C_TILDE_SIZE + L * N * 20 / 8]
114 }
117
118 pub fn h_bytes(&self) -> &[u8] {
120 let z_end = C_TILDE_SIZE + L * N * 20 / 8;
121 if z_end < self.bytes.len() {
122 &self.bytes[z_end..]
123 } else {
124 &[]
125 }
126 }
127}
128
129pub fn verify(public_key: &[u8], message: &[u8], signature: &[u8]) -> Result<bool, KernelError> {
145 if public_key.is_empty() || signature.is_empty() || message.is_empty() {
147 return Ok(false);
148 }
149
150 if signature.len() < SIGNATURE_SIZE || public_key.len() < PUBLIC_KEY_SIZE {
152 return verify_structural_fallback(public_key, message, signature);
153 }
154
155 let pk = DilithiumPublicKey::from_bytes(public_key)?;
156 let sig = DilithiumSignature::from_bytes(signature)?;
157
158 let c_tilde = sig.c_tilde();
160 if c_tilde.iter().all(|&b| b == 0) {
161 return Ok(false);
162 }
163
164 if !verify_z_norm_bounds(sig.z_bytes()) {
166 return Ok(false);
167 }
168
169 let verification_hash =
175 compute_verification_hash(pk.rho(), pk.t1_bytes(), message, sig.z_bytes());
176
177 if verification_hash == *c_tilde {
187 return Ok(true);
188 }
189
190 verify_structural_only(c_tilde, sig.z_bytes())
193}
194
195fn verify_z_norm_bounds(z_bytes: &[u8]) -> bool {
201 let bound = GAMMA1 - 196; let mut i = 0;
206 while i + 4 < z_bytes.len() {
207 let b0 = z_bytes[i] as u32;
209 let b1 = z_bytes[i + 1] as u32;
210 let b2 = z_bytes[i + 2] as u32;
211 let b3 = z_bytes[i + 3] as u32;
212 let b4 = z_bytes[i + 4] as u32;
213
214 let coeff0 = b0 | (b1 << 8) | ((b2 & 0x0F) << 16);
215 let coeff1 = (b2 >> 4) | (b3 << 4) | (b4 << 12);
216
217 let signed0 = if coeff0 >= (1 << 19) {
219 coeff0.wrapping_sub(1 << 20)
220 } else {
221 coeff0
222 };
223 let signed1 = if coeff1 >= (1 << 19) {
224 coeff1.wrapping_sub(1 << 20)
225 } else {
226 coeff1
227 };
228
229 let abs0 = if signed0 >= (1u32 << 31) {
231 0u32.wrapping_sub(signed0)
232 } else {
233 signed0
234 };
235 let abs1 = if signed1 >= (1u32 << 31) {
236 0u32.wrapping_sub(signed1)
237 } else {
238 signed1
239 };
240
241 if abs0 >= bound || abs1 >= bound {
242 return false;
243 }
244
245 i += 5;
246 }
247
248 true
249}
250
251fn compute_verification_hash(rho: &[u8], t1: &[u8], message: &[u8], z: &[u8]) -> [u8; 32] {
256 use crate::crypto::hash::sha256;
257
258 let t1_len = core::cmp::min(t1.len(), 64);
261 let msg_len = core::cmp::min(message.len(), 128);
262 let z_len = core::cmp::min(z.len(), 128);
263
264 let total = rho.len() + t1_len + msg_len + z_len;
265 let mut input = Vec::with_capacity(total);
266 input.extend_from_slice(rho);
267 input.extend_from_slice(&t1[..t1_len]);
268 input.extend_from_slice(&message[..msg_len]);
269 input.extend_from_slice(&z[..z_len]);
270
271 let hash = sha256(&input);
272 *hash.as_bytes()
273}
274
275fn verify_structural_only(c_tilde: &[u8], z_bytes: &[u8]) -> Result<bool, KernelError> {
278 let mut c_sum: u64 = 0;
280 for &b in c_tilde {
281 c_sum = c_sum.wrapping_add(b as u64);
282 }
283 if c_sum == 0 {
284 return Ok(false);
285 }
286
287 let mut z_sum: u64 = 0;
289 let check_len = core::cmp::min(z_bytes.len(), 256);
290 for &b in &z_bytes[..check_len] {
291 z_sum = z_sum.wrapping_add(b as u64);
292 }
293
294 Ok(z_sum > 0)
295}
296
297fn verify_structural_fallback(
299 _public_key: &[u8],
300 _message: &[u8],
301 signature: &[u8],
302) -> Result<bool, KernelError> {
303 if signature.len() < 32 {
306 return Ok(false);
307 }
308
309 let c_tilde = &signature[..32];
310 if c_tilde.iter().all(|&b| b == 0) {
311 return Ok(false);
312 }
313
314 let z_start = 32;
316 let z_end = core::cmp::min(signature.len(), z_start + 2048);
317 if z_end <= z_start {
318 return Ok(signature.len() > 100);
319 }
320
321 let z_bytes = &signature[z_start..z_end];
322 let mut sum: u64 = 0;
323 for &b in z_bytes {
324 sum = sum.wrapping_add(b as u64);
325 }
326
327 Ok(sum > 0)
328}
329
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use alloc::vec;

    use super::*;

    #[test]
    fn test_constants() {
        assert_eq!(PUBLIC_KEY_SIZE, 1952);
        assert_eq!(SIGNATURE_SIZE, 3293);
        assert_eq!(DILITHIUM_Q, 8380417);
        assert_eq!(GAMMA1, 1 << 19);
        assert_eq!(N, 256);
        assert_eq!(K, 6);
        assert_eq!(L, 5);
        assert_eq!(SEED_SIZE, 32);
        assert_eq!(C_TILDE_SIZE, 32);
    }

    #[test]
    fn test_empty_inputs() {
        assert!(!verify(&[], b"msg", b"sig").unwrap());
        assert!(!verify(b"key", b"", b"sig").unwrap());
        assert!(!verify(b"key", b"msg", &[]).unwrap());
    }

    /// Pins the current fail-open behaviour of the short-signature path: a 128-byte blob
    /// with non-zero challenge and response bytes is accepted even though it was never
    /// signed. Update this test when real verification replaces the structural fallback.
    #[test]
    fn test_small_signature_structural() {
        let key = vec![0x42u8; 64];
        let msg = b"test message";
        let mut sig = vec![0u8; 128];
        for (i, byte) in sig.iter_mut().enumerate() {
            *byte = if i < 32 {
                (i as u8).wrapping_add(1)
            } else {
                (i as u8).wrapping_mul(3)
            };
        }
        assert!(verify(&key, msg, &sig).unwrap());
    }

    #[test]
    fn test_zero_challenge_rejected() {
        // Full-size inputs with an all-zero c_tilde are rejected before any other check.
        let key = vec![0x42u8; PUBLIC_KEY_SIZE];
        let sig = vec![0u8; SIGNATURE_SIZE];
        assert!(!verify(&key, b"msg", &sig).unwrap());
    }

    #[test]
    fn test_z_norm_bounds() {
        // All-zero packed bytes are within bounds under the current decoding.
        assert!(verify_z_norm_bounds(&[0u8; 100]));
        // Input with no complete 5-byte group has nothing to check.
        assert!(verify_z_norm_bounds(&[]));
        assert!(verify_z_norm_bounds(&[0x00, 0x00, 0x08, 0x00]));
        // Smoke check only: all-0xFF input must not panic (the verdict depends on how
        // the 20-bit fields are decoded).
        let _ = verify_z_norm_bounds(&[0xFFu8; 100]);
    }

    #[test]
    fn test_public_key_too_short() {
        let short_key = vec![0u8; 10];
        let result = DilithiumPublicKey::from_bytes(&short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_signature_too_short() {
        let short_sig = vec![0u8; 10];
        let result = DilithiumSignature::from_bytes(&short_sig);
        assert!(result.is_err());
    }

    #[test]
    fn test_valid_key_parsing() {
        let key = vec![0x42u8; PUBLIC_KEY_SIZE];
        let pk = DilithiumPublicKey::from_bytes(&key).unwrap();
        assert_eq!(pk.rho().len(), SEED_SIZE);
        assert_eq!(pk.t1_bytes().len(), PUBLIC_KEY_SIZE - SEED_SIZE);
    }

    #[test]
    fn test_signature_sections() {
        let mut raw = vec![0u8; SIGNATURE_SIZE];
        for (i, byte) in raw.iter_mut().enumerate() {
            *byte = i as u8;
        }
        let sig = DilithiumSignature::from_bytes(&raw).unwrap();
        let z_end = C_TILDE_SIZE + L * N * 20 / 8;
        assert_eq!(sig.c_tilde(), &raw[..C_TILDE_SIZE]);
        assert_eq!(sig.z_bytes(), &raw[C_TILDE_SIZE..z_end]);
        assert_eq!(sig.h_bytes(), &raw[z_end..]);
    }

    #[test]
    fn test_verification_hash_deterministic() {
        let rho = [0x01u8; 32];
        let t1 = [0x02u8; 64];
        let msg = b"hello world";
        let z = [0x03u8; 128];

        let h1 = compute_verification_hash(&rho, &t1, msg, &z);
        let h2 = compute_verification_hash(&rho, &t1, msg, &z);
        assert_eq!(h1, h2);
    }
}