⚠️ VeridianOS Kernel Documentation - This is low-level kernel code. All functions are unsafe unless explicitly marked otherwise. no_std

veridian_kernel/net/tls/
cipher.rs

1//! TLS 1.3 Cipher Suites and Cryptographic Primitives
2//!
3//! Implements all crypto needed for TLS 1.3:
4//! - HMAC-SHA256 (RFC 2104)
5//! - HKDF-SHA256 (RFC 5869)
6//! - X25519 key exchange (RFC 7748)
7//! - ChaCha20-Poly1305 AEAD (RFC 8439)
8//! - AES-128-GCM AEAD (NIST SP 800-38D)
9
10#[cfg(feature = "alloc")]
11use alloc::vec::Vec;
12
13use super::{CipherSuite, AEAD_TAG_LEN, AES_128_KEY_LEN, CHACHA20_KEY_LEN, HASH_LEN, NONCE_LEN};
14use crate::crypto::hash::{sha256, Hash256};
15
16// ============================================================================
17// HMAC-SHA256 (RFC 2104)
18// ============================================================================
19
20/// HMAC-SHA256 (RFC 2104)
21///
22/// Stack-only implementation -- no heap allocation for the HMAC computation.
23pub fn hmac_sha256(key: &[u8], message: &[u8]) -> [u8; 32] {
24    const BLOCK_SIZE: usize = 64;
25
26    // If key > block size, hash it first
27    let key_hash: Hash256;
28    let k = if key.len() > BLOCK_SIZE {
29        key_hash = sha256(key);
30        key_hash.as_bytes().as_slice()
31    } else {
32        key
33    };
34
35    let mut ipad = [0x36u8; BLOCK_SIZE];
36    let mut opad = [0x5cu8; BLOCK_SIZE];
37
38    for i in 0..k.len() {
39        ipad[i] ^= k[i];
40        opad[i] ^= k[i];
41    }
42
43    // Inner hash: SHA256(ipad || message)
44    // We build this on the stack with a reasonable buffer size.
45    // For messages larger than this, we'd need a streaming SHA256.
46    let mut inner_buf = [0u8; 2048];
47    let inner_len = BLOCK_SIZE + message.len();
48    if inner_len <= inner_buf.len() {
49        inner_buf[..BLOCK_SIZE].copy_from_slice(&ipad);
50        inner_buf[BLOCK_SIZE..inner_len].copy_from_slice(message);
51        let inner_hash = sha256(&inner_buf[..inner_len]);
52
53        // Outer hash: SHA256(opad || inner_hash)
54        let mut outer_buf = [0u8; 96]; // 64 + 32
55        outer_buf[..BLOCK_SIZE].copy_from_slice(&opad);
56        outer_buf[BLOCK_SIZE..BLOCK_SIZE + 32].copy_from_slice(inner_hash.as_bytes());
57        sha256(&outer_buf[..BLOCK_SIZE + 32]).0
58    } else {
59        // Fallback for very large messages: use alloc
60        let mut inner_data = Vec::with_capacity(inner_len);
61        inner_data.extend_from_slice(&ipad);
62        inner_data.extend_from_slice(message);
63        let inner_hash = sha256(&inner_data);
64
65        let mut outer_buf = [0u8; 96];
66        outer_buf[..BLOCK_SIZE].copy_from_slice(&opad);
67        outer_buf[BLOCK_SIZE..BLOCK_SIZE + 32].copy_from_slice(inner_hash.as_bytes());
68        sha256(&outer_buf[..BLOCK_SIZE + 32]).0
69    }
70}
71
72// ============================================================================
73// HKDF-SHA256 (RFC 5869)
74// ============================================================================
75
/// HKDF-Extract (RFC 5869 Section 2.2): PRK = HMAC-Hash(salt, IKM)
///
/// Note the argument order: the *salt* is the HMAC key and the input
/// keying material (IKM) is the HMAC message.
pub fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> [u8; HASH_LEN] {
    hmac_sha256(salt, ikm)
}
80
81/// HKDF-Expand: OKM = T(1) || T(2) || ... (truncated to length)
82///
83/// T(0) = empty string
84/// T(i) = HMAC-Hash(PRK, T(i-1) || info || i)
85pub fn hkdf_expand(prk: &[u8; HASH_LEN], info: &[u8], length: usize) -> Vec<u8> {
86    let n = length.div_ceil(HASH_LEN);
87    let mut okm = Vec::with_capacity(n * HASH_LEN);
88    let mut t = [0u8; HASH_LEN];
89    let mut t_len: usize = 0;
90
91    for i in 1..=n {
92        // HMAC input: T(i-1) || info || i
93        let mut input = Vec::with_capacity(t_len + info.len() + 1);
94        if t_len > 0 {
95            input.extend_from_slice(&t[..t_len]);
96        }
97        input.extend_from_slice(info);
98        input.push(i as u8);
99
100        t = hmac_sha256(prk, &input);
101        t_len = HASH_LEN;
102        okm.extend_from_slice(&t);
103    }
104
105    okm.truncate(length);
106    okm
107}
108
109/// HKDF-Expand-Label (TLS 1.3 specific, RFC 8446 Section 7.1)
110///
111/// HKDF-Expand-Label(Secret, Label, Context, Length) =
112///     HKDF-Expand(Secret, HkdfLabel, Length)
113/// where HkdfLabel = Length(2) || "tls13 " || Label || Context
114pub fn hkdf_expand_label(
115    secret: &[u8; HASH_LEN],
116    label: &[u8],
117    context: &[u8],
118    length: usize,
119) -> Vec<u8> {
120    let tls_label = b"tls13 ";
121    let mut hkdf_label =
122        Vec::with_capacity(2 + 1 + tls_label.len() + label.len() + 1 + context.len());
123
124    // Length (2 bytes, big-endian)
125    hkdf_label.extend_from_slice(&(length as u16).to_be_bytes());
126
127    // Label with "tls13 " prefix (length-prefixed)
128    let full_label_len = tls_label.len() + label.len();
129    hkdf_label.push(full_label_len as u8);
130    hkdf_label.extend_from_slice(tls_label);
131    hkdf_label.extend_from_slice(label);
132
133    // Context (length-prefixed)
134    hkdf_label.push(context.len() as u8);
135    hkdf_label.extend_from_slice(context);
136
137    hkdf_expand(secret, &hkdf_label, length)
138}
139
140/// Derive-Secret (TLS 1.3, RFC 8446 Section 7.1)
141///
142/// Derive-Secret(Secret, Label, Messages) =
143///     HKDF-Expand-Label(Secret, Label, Transcript-Hash(Messages), Hash.length)
144pub(crate) fn derive_secret(
145    secret: &[u8; HASH_LEN],
146    label: &[u8],
147    transcript_hash: &[u8; 32],
148) -> [u8; HASH_LEN] {
149    let expanded = hkdf_expand_label(secret, label, transcript_hash, HASH_LEN);
150    let mut result = [0u8; HASH_LEN];
151    result.copy_from_slice(&expanded);
152    result
153}
154
155// ============================================================================
156// X25519 Key Exchange (RFC 7748)
157// ============================================================================
158
/// X25519 basepoint (u = 9), encoded little-endian per RFC 7748 Section 4.1.
const X25519_BASEPOINT: [u8; 32] = {
    let mut b = [0u8; 32];
    b[0] = 9; // little-endian: the value 9 sits in the lowest byte
    b
};
165
/// Generate an X25519 keypair using the kernel's CSPRNG
///
/// Returns `(private_key, public_key)` where
/// `public_key = scalar_mult(private_key, basepoint)`. Clamping of the
/// private scalar happens inside `x25519_scalar_mult`.
///
/// SECURITY NOTE(review): if `SecureRandom::new()` fails this silently
/// falls back to a fixed, fully predictable private key; any handshake
/// using it is breakable. Confirm the fallback is unreachable in
/// production, or make this function fallible.
pub fn x25519_keypair() -> ([u8; 32], [u8; 32]) {
    let mut private_key = [0u8; 32];
    // Use kernel CSPRNG if available, otherwise deterministic seed for testing
    if let Ok(rng) = crate::crypto::random::SecureRandom::new() {
        // NOTE(review): a fill_bytes error is ignored here, leaving the
        // buffer zeroed or partially filled — see the security note above.
        let _ = rng.fill_bytes(&mut private_key);
    } else {
        // Fallback: deterministic but non-zero key for testing only
        for (i, b) in private_key.iter_mut().enumerate() {
            *b = (i as u8).wrapping_add(42);
        }
    }

    let public_key = x25519_scalar_mult(&private_key, &X25519_BASEPOINT);
    (private_key, public_key)
}
182
/// Compute X25519 shared secret: shared = scalar_mult(our_private,
/// their_public)
///
/// NOTE(review): RFC 7748 Section 6.1 recommends rejecting an all-zero
/// shared secret (non-contributory peer keys). No such check happens here;
/// confirm it is performed at the handshake layer.
pub fn x25519_shared_secret(private_key: &[u8; 32], peer_public: &[u8; 32]) -> [u8; 32] {
    x25519_scalar_mult(private_key, peer_public)
}
188
/// X25519 scalar multiplication using the Montgomery ladder.
///
/// Implements RFC 7748 Section 5 with clamping. The ladder walks the 255
/// scalar bits from high to low; the conditional swaps are masked
/// (`fe_cswap`), so the sequence of field operations is independent of the
/// secret scalar bits.
pub(crate) fn x25519_scalar_mult(scalar: &[u8; 32], u_point: &[u8; 32]) -> [u8; 32] {
    // Clamp scalar per RFC 7748: clear the low 3 bits, clear bit 255, set bit 254.
    let mut k = *scalar;
    k[0] &= 248;
    k[31] &= 127;
    k[31] |= 64;

    // Load u-coordinate (bit 255 of the encoding is masked off by fe_from_bytes)
    let u = fe_from_bytes(u_point);

    // Montgomery ladder state: (x_2, z_2) ~ [n]P and (x_3, z_3) ~ [n+1]P
    // in projective coordinates.
    let mut x_2 = fe_one();
    let mut z_2 = fe_zero();
    let mut x_3 = u;
    let mut z_3 = fe_one();
    let mut swap: u64 = 0;

    for pos in (0..255).rev() {
        let bit = ((k[pos >> 3] >> (pos & 7)) & 1) as u64;
        // Lazy-swap formulation from RFC 7748: swap only when the current
        // bit differs from the previous one.
        swap ^= bit;
        fe_cswap(&mut x_2, &mut x_3, swap);
        fe_cswap(&mut z_2, &mut z_3, swap);
        swap = bit;

        // One combined ladder step (differential addition + doubling),
        // formulas from RFC 7748 Section 5.
        let a = fe_add(&x_2, &z_2);
        let aa = fe_sq(&a);
        let b = fe_sub(&x_2, &z_2);
        let bb = fe_sq(&b);
        let e = fe_sub(&aa, &bb);
        let c = fe_add(&x_3, &z_3);
        let d = fe_sub(&x_3, &z_3);
        let da = fe_mul(&d, &a);
        let cb = fe_mul(&c, &b);
        x_3 = fe_sq(&fe_add(&da, &cb));
        z_3 = fe_mul(&u, &fe_sq(&fe_sub(&da, &cb)));
        x_2 = fe_mul(&aa, &bb);
        // a24 = (A-2)/4 = (486662-2)/4 = 121665 per RFC 7748
        z_2 = fe_mul(&e, &fe_add(&aa, &fe_mul_scalar(&e, 121665)));
    }

    // Undo the final pending swap.
    fe_cswap(&mut x_2, &mut x_3, swap);
    fe_cswap(&mut z_2, &mut z_3, swap);

    // Affine result: x_2 / z_2 (z_2 = 0 inverts to 0 in this Fermat-based
    // inversion, yielding the all-zero output RFC 7748 arithmetic prescribes).
    let result = fe_mul(&x_2, &fe_invert(&z_2));
    fe_to_bytes(&result)
}
238
239// --- GF(2^255-19) Field Arithmetic (5-limb, 51 bits per limb) ---
240
/// Field element of GF(2^255-19): five little-endian limbs of 51 bits each
/// (limb i carries bits 51*i .. 51*i+50). Limbs may be transiently unreduced.
type Fe = [u64; 5];
/// Mask selecting the low 51 bits of a limb.
const LIMB_MASK: u64 = (1u64 << 51) - 1;
243
/// The field element 0.
fn fe_zero() -> Fe {
    [0; 5]
}
247
/// The field element 1.
fn fe_one() -> Fe {
    [1, 0, 0, 0, 0]
}
251
/// Decode a 32-byte little-endian value into 51-bit limbs.
///
/// Limb i holds bits 51*i .. 51*i+50. The top bit of byte 31 (bit 255)
/// falls outside limb 4's mask and is therefore ignored, as RFC 7748
/// requires for u-coordinate decoding.
fn fe_from_bytes(s: &[u8; 32]) -> Fe {
    // Read up to 8 bytes little-endian, zero-padding past the end of `s`.
    let load64 = |bytes: &[u8]| -> u64 {
        let mut buf = [0u8; 8];
        let len = core::cmp::min(bytes.len(), 8);
        buf[..len].copy_from_slice(&bytes[..len]);
        u64::from_le_bytes(buf)
    };

    // Each limb starts at bit 51*i = 8*byte_offset + shift.
    let mut h = [0u64; 5];
    h[0] = load64(&s[0..]) & LIMB_MASK; // bits 0..50
    h[1] = (load64(&s[6..]) >> 3) & LIMB_MASK; // bits 51..101  (51 = 6*8 + 3)
    h[2] = (load64(&s[12..]) >> 6) & LIMB_MASK; // bits 102..152 (102 = 12*8 + 6)
    h[3] = (load64(&s[19..]) >> 1) & LIMB_MASK; // bits 153..203 (153 = 19*8 + 1)
    h[4] = (load64(&s[24..]) >> 12) & LIMB_MASK; // bits 204..254 (204 = 24*8 + 12)
    h
}
268
/// Encode a field element as 32 little-endian bytes, fully reduced mod p.
fn fe_to_bytes(h: &Fe) -> [u8; 32] {
    let mut t = *h;
    fe_reduce(&mut t);

    // Final conditional subtraction: q ends up 1 iff t >= p (detected by
    // whether t + 19 carries out of bit 255). Then t += 19*q and the limb
    // masking below drops the resulting 2^255, i.e. t becomes t - p when q = 1.
    let mut q = (t[0].wrapping_add(19)) >> 51;
    q = (t[1].wrapping_add(q)) >> 51;
    q = (t[2].wrapping_add(q)) >> 51;
    q = (t[3].wrapping_add(q)) >> 51;
    q = (t[4].wrapping_add(q)) >> 51;

    t[0] = t[0].wrapping_add(19u64.wrapping_mul(q));
    let mut carry = t[0] >> 51;
    t[0] &= LIMB_MASK;
    #[allow(clippy::needless_range_loop)]
    for i in 1..5 {
        t[i] = t[i].wrapping_add(carry);
        carry = t[i] >> 51;
        t[i] &= LIMB_MASK;
    }

    // Serialize 5 limbs (51 bits each) to 32 bytes via u128 accumulator:
    // stream limb bits into `acc` and emit a byte whenever 8+ are buffered.
    let mut bits = [0u8; 32];
    let mut acc: u128 = 0;
    let mut acc_bits: u32 = 0;
    let mut byte_pos = 0;
    for &limb in t.iter() {
        acc |= (limb as u128) << acc_bits;
        acc_bits += 51;
        while acc_bits >= 8 && byte_pos < 32 {
            bits[byte_pos] = (acc & 0xFF) as u8;
            acc >>= 8;
            acc_bits -= 8;
            byte_pos += 1;
        }
    }
    // Handle any remaining bits (the final partial byte: 255 = 31*8 + 7)
    if byte_pos < 32 {
        bits[byte_pos] = (acc & 0xFF) as u8;
    }
    bits
}
311
312fn fe_reduce(h: &mut Fe) {
313    let mut carry: u64;
314    for _ in 0..2 {
315        carry = h[0] >> 51;
316        h[0] &= LIMB_MASK;
317        h[1] = h[1].wrapping_add(carry);
318
319        carry = h[1] >> 51;
320        h[1] &= LIMB_MASK;
321        h[2] = h[2].wrapping_add(carry);
322
323        carry = h[2] >> 51;
324        h[2] &= LIMB_MASK;
325        h[3] = h[3].wrapping_add(carry);
326
327        carry = h[3] >> 51;
328        h[3] &= LIMB_MASK;
329        h[4] = h[4].wrapping_add(carry);
330
331        carry = h[4] >> 51;
332        h[4] &= LIMB_MASK;
333        h[0] = h[0].wrapping_add(carry.wrapping_mul(19));
334    }
335}
336
337fn fe_add(a: &Fe, b: &Fe) -> Fe {
338    [
339        a[0].wrapping_add(b[0]),
340        a[1].wrapping_add(b[1]),
341        a[2].wrapping_add(b[2]),
342        a[3].wrapping_add(b[3]),
343        a[4].wrapping_add(b[4]),
344    ]
345}
346
347fn fe_sub(a: &Fe, b: &Fe) -> Fe {
348    // Add p to avoid underflow before subtraction
349    let bias: u64 = (1u64 << 51) - 1;
350    let bias0: u64 = bias - 18;
351    [
352        a[0].wrapping_add(bias0).wrapping_sub(b[0]),
353        a[1].wrapping_add(bias).wrapping_sub(b[1]),
354        a[2].wrapping_add(bias).wrapping_sub(b[2]),
355        a[3].wrapping_add(bias).wrapping_sub(b[3]),
356        a[4].wrapping_add(bias).wrapping_sub(b[4]),
357    ]
358}
359
/// Schoolbook multiplication in GF(2^255-19).
///
/// Accumulates the 25 partial products in u128 (each product plus the *19
/// fold fits comfortably, so the sums cannot overflow), maps product limbs
/// at position >= 5 back down using 2^255 = 19 (mod p), then carries the
/// result into 51-bit limbs.
#[allow(clippy::needless_range_loop)]
fn fe_mul(a: &Fe, b: &Fe) -> Fe {
    let mut t = [0u128; 5];

    for i in 0..5 {
        for j in 0..5 {
            let product = (a[i] as u128) * (b[j] as u128);
            let idx = i + j;
            if idx < 5 {
                t[idx] = t[idx].wrapping_add(product);
            } else {
                // Reduce: limb at position idx maps to idx-5 with factor 19
                t[idx - 5] = t[idx - 5].wrapping_add(product.wrapping_mul(19));
            }
        }
    }

    // Carry chain down to 51-bit limbs; the carry leaving limb 4 re-enters
    // at limb 0 scaled by 19 (again because 2^255 = 19 mod p).
    let mut h = [0u64; 5];
    let mut carry: u128 = 0;
    for i in 0..5 {
        t[i] = t[i].wrapping_add(carry);
        h[i] = (t[i] as u64) & LIMB_MASK;
        carry = t[i] >> 51;
    }
    h[0] = h[0].wrapping_add((carry as u64).wrapping_mul(19));

    fe_reduce(&mut h);
    h
}
389
/// Squaring, implemented via the general multiply (no dedicated fast path).
fn fe_sq(a: &Fe) -> Fe {
    fe_mul(a, a)
}
393
/// Multiply a field element by a small word-sized scalar
/// (used in the ladder for a24 = 121665).
fn fe_mul_scalar(a: &Fe, s: u64) -> Fe {
    let mut h = [0u64; 5];
    let mut carry: u128 = 0;
    for i in 0..5 {
        let product = (a[i] as u128) * (s as u128) + carry;
        h[i] = (product as u64) & LIMB_MASK;
        carry = product >> 51;
    }
    // 2^255 = 19 (mod p): fold the top carry back into limb 0.
    h[0] = h[0].wrapping_add((carry as u64).wrapping_mul(19));
    fe_reduce(&mut h);
    h
}
406
/// Constant-time conditional swap: exchanges `a` and `b` iff `swap` == 1.
///
/// The branch-free mask keeps the instruction sequence and memory access
/// pattern independent of `swap`, as the Montgomery ladder requires.
fn fe_cswap(a: &mut Fe, b: &mut Fe, swap: u64) {
    let mask = 0u64.wrapping_sub(swap); // 0 or 0xFFFFFFFFFFFFFFFF
    for i in 0..5 {
        let t = mask & (a[i] ^ b[i]);
        a[i] ^= t;
        b[i] ^= t;
    }
}
415
/// Compute modular inverse using Fermat's little theorem: a^(p-2) mod p
///
/// The exponent is p-2 = 2^255 - 21. The addition chain below builds
/// powers of the form z^(2^k - 1) and combines them (classic ref10/donna
/// chain; `z_a_0` denotes z^(2^a - 1)).
fn fe_invert(z: &Fe) -> Fe {
    // p-2 = 2^255 - 21
    // Use addition chain for efficient exponentiation
    let z2 = fe_sq(z); // z^2
    let z9 = {
        let z4 = fe_sq(&z2); // z^4
        let z8 = fe_sq(&z4); // z^8
        fe_mul(&z8, z) // z^9
    };
    let z11 = fe_mul(&z9, &z2); // z^11
    let z_5_0 = {
        let t = fe_sq(&z11); // z^22
        fe_mul(&t, &z9) // z^31 = z^(2^5 - 1)
    };
    let z_10_0 = {
        // 5 squarings of z^(2^5 - 1), then * z^(2^5 - 1) => z^(2^10 - 1)
        let mut t = fe_sq(&z_5_0);
        for _ in 1..5 {
            t = fe_sq(&t);
        }
        fe_mul(&t, &z_5_0)
    };
    let z_20_0 = {
        // z^(2^20 - 1)
        let mut t = fe_sq(&z_10_0);
        for _ in 1..10 {
            t = fe_sq(&t);
        }
        fe_mul(&t, &z_10_0)
    };
    let z_40_0 = {
        // z^(2^40 - 1)
        let mut t = fe_sq(&z_20_0);
        for _ in 1..20 {
            t = fe_sq(&t);
        }
        fe_mul(&t, &z_20_0)
    };
    let z_50_0 = {
        // z^(2^50 - 1)
        let mut t = fe_sq(&z_40_0);
        for _ in 1..10 {
            t = fe_sq(&t);
        }
        fe_mul(&t, &z_10_0)
    };
    let z_100_0 = {
        // z^(2^100 - 1)
        let mut t = fe_sq(&z_50_0);
        for _ in 1..50 {
            t = fe_sq(&t);
        }
        fe_mul(&t, &z_50_0)
    };
    let z_200_0 = {
        // z^(2^200 - 1)
        let mut t = fe_sq(&z_100_0);
        for _ in 1..100 {
            t = fe_sq(&t);
        }
        fe_mul(&t, &z_100_0)
    };
    let z_250_0 = {
        // z^(2^250 - 1)
        let mut t = fe_sq(&z_200_0);
        for _ in 1..50 {
            t = fe_sq(&t);
        }
        fe_mul(&t, &z_50_0)
    };

    {
        // Final: 5 squarings give z^(2^255 - 2^5); * z^11 => z^(2^255 - 21)
        let mut t = fe_sq(&z_250_0);
        for _ in 1..5 {
            t = fe_sq(&t);
        }
        fe_mul(&t, &z11)
    }
}
489
490// ============================================================================
491// ChaCha20-Poly1305 AEAD (RFC 8439)
492// ============================================================================
493
/// ChaCha20 quarter round (RFC 8439 Section 2.1), applied to the state
/// words at indices `a`, `b`, `c`, `d`.
#[inline]
fn chacha20_quarter_round(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
    // Work on local copies, write back once at the end.
    let (mut x, mut y, mut z, mut w) = (state[a], state[b], state[c], state[d]);

    x = x.wrapping_add(y);
    w = (w ^ x).rotate_left(16);

    z = z.wrapping_add(w);
    y = (y ^ z).rotate_left(12);

    x = x.wrapping_add(y);
    w = (w ^ x).rotate_left(8);

    z = z.wrapping_add(w);
    y = (y ^ z).rotate_left(7);

    state[a] = x;
    state[b] = y;
    state[c] = z;
    state[d] = w;
}
513
514/// Generate one 64-byte ChaCha20 keystream block
515fn chacha20_block(key: &[u8; 32], nonce: &[u8; 12], counter: u32) -> [u8; 64] {
516    let mut state: [u32; 16] = [
517        0x61707865,
518        0x3320646e,
519        0x79622d32,
520        0x6b206574, // "expand 32-byte k"
521        u32::from_le_bytes([key[0], key[1], key[2], key[3]]),
522        u32::from_le_bytes([key[4], key[5], key[6], key[7]]),
523        u32::from_le_bytes([key[8], key[9], key[10], key[11]]),
524        u32::from_le_bytes([key[12], key[13], key[14], key[15]]),
525        u32::from_le_bytes([key[16], key[17], key[18], key[19]]),
526        u32::from_le_bytes([key[20], key[21], key[22], key[23]]),
527        u32::from_le_bytes([key[24], key[25], key[26], key[27]]),
528        u32::from_le_bytes([key[28], key[29], key[30], key[31]]),
529        counter,
530        u32::from_le_bytes([nonce[0], nonce[1], nonce[2], nonce[3]]),
531        u32::from_le_bytes([nonce[4], nonce[5], nonce[6], nonce[7]]),
532        u32::from_le_bytes([nonce[8], nonce[9], nonce[10], nonce[11]]),
533    ];
534
535    let initial = state;
536
537    // 20 rounds (10 double rounds)
538    for _ in 0..10 {
539        // Column rounds
540        chacha20_quarter_round(&mut state, 0, 4, 8, 12);
541        chacha20_quarter_round(&mut state, 1, 5, 9, 13);
542        chacha20_quarter_round(&mut state, 2, 6, 10, 14);
543        chacha20_quarter_round(&mut state, 3, 7, 11, 15);
544        // Diagonal rounds
545        chacha20_quarter_round(&mut state, 0, 5, 10, 15);
546        chacha20_quarter_round(&mut state, 1, 6, 11, 12);
547        chacha20_quarter_round(&mut state, 2, 7, 8, 13);
548        chacha20_quarter_round(&mut state, 3, 4, 9, 14);
549    }
550
551    let mut output = [0u8; 64];
552    for i in 0..16 {
553        let val = state[i].wrapping_add(initial[i]);
554        output[i * 4..(i + 1) * 4].copy_from_slice(&val.to_le_bytes());
555    }
556    output
557}
558
559/// ChaCha20 encrypt/decrypt (XOR with keystream)
560pub(crate) fn chacha20_crypt(
561    key: &[u8; 32],
562    nonce: &[u8; 12],
563    counter: u32,
564    data: &[u8],
565) -> Vec<u8> {
566    let mut output = Vec::with_capacity(data.len());
567    let mut ctr = counter;
568
569    for chunk in data.chunks(64) {
570        let block = chacha20_block(key, nonce, ctr);
571        for (i, &b) in chunk.iter().enumerate() {
572            output.push(b ^ block[i]);
573        }
574        ctr = ctr.wrapping_add(1);
575    }
576
577    output
578}
579
/// Poly1305 one-time authenticator (RFC 8439 Section 2.5).
///
/// The accumulator is kept as five 26-bit limbs (the classic
/// "poly1305-donna" radix-2^26 layout) so every product fits in u64 and the
/// full 130-bit state is represented exactly. The previous implementation
/// folded the state into a single u128, which cannot hold 130-bit values:
/// the 2^128 block-sentinel bit (added as 2^127 + 2^127) wrapped to zero,
/// and the multiply helper discarded carries, so tags disagreed with the
/// RFC 8439 test vectors.
fn poly1305_mac(key: &[u8; 32], message: &[u8]) -> [u8; 16] {
    const MASK26: u32 = 0x03ff_ffff;

    #[inline]
    fn le32(b: &[u8]) -> u32 {
        u32::from_le_bytes([b[0], b[1], b[2], b[3]])
    }

    // r = key[0..16], clamped. The masks below combine the limb split with
    // the RFC 8439 clamp (top 4 bits of bytes 3/7/11/15, low 2 bits of
    // bytes 4/8/12 cleared).
    let r0 = le32(&key[0..4]) & 0x03ff_ffff;
    let r1 = (le32(&key[3..7]) >> 2) & 0x03ff_ff03;
    let r2 = (le32(&key[6..10]) >> 4) & 0x03ff_c0ff;
    let r3 = (le32(&key[9..13]) >> 6) & 0x03f0_3fff;
    let r4 = (le32(&key[12..16]) >> 8) & 0x000f_ffff;

    // Precomputed 5*r limbs for the 2^130 = 5 (mod p) wrap-around terms.
    let s1 = r1 * 5;
    let s2 = r2 * 5;
    let s3 = r3 * 5;
    let s4 = r4 * 5;

    let (mut h0, mut h1, mut h2, mut h3, mut h4) = (0u32, 0u32, 0u32, 0u32, 0u32);

    for chunk in message.chunks(16) {
        // A short final block is padded with 0x01 then zeros; full blocks
        // get the 2^128 bit via `hibit` instead (bit 128 - 4*26 = bit 24
        // of limb 4).
        let mut block = [0u8; 16];
        block[..chunk.len()].copy_from_slice(chunk);
        let hibit = if chunk.len() == 16 {
            1u32 << 24
        } else {
            block[chunk.len()] = 1;
            0
        };

        // h += block, as a 130-bit little-endian number split into limbs.
        h0 += le32(&block[0..4]) & MASK26;
        h1 += (le32(&block[3..7]) >> 2) & MASK26;
        h2 += (le32(&block[6..10]) >> 4) & MASK26;
        h3 += (le32(&block[9..13]) >> 6) & MASK26;
        h4 += (le32(&block[12..16]) >> 8) | hibit;

        // h *= r (mod 2^130 - 5): schoolbook multiply where wrapped limbs
        // carry the factor 5.
        let (h0w, h1w, h2w, h3w, h4w) =
            (h0 as u64, h1 as u64, h2 as u64, h3 as u64, h4 as u64);
        let d0 = h0w * r0 as u64 + h1w * s4 as u64 + h2w * s3 as u64
            + h3w * s2 as u64 + h4w * s1 as u64;
        let mut d1 = h0w * r1 as u64 + h1w * r0 as u64 + h2w * s4 as u64
            + h3w * s3 as u64 + h4w * s2 as u64;
        let mut d2 = h0w * r2 as u64 + h1w * r1 as u64 + h2w * r0 as u64
            + h3w * s4 as u64 + h4w * s3 as u64;
        let mut d3 = h0w * r3 as u64 + h1w * r2 as u64 + h2w * r1 as u64
            + h3w * r0 as u64 + h4w * s4 as u64;
        let mut d4 = h0w * r4 as u64 + h1w * r3 as u64 + h2w * r2 as u64
            + h3w * r1 as u64 + h4w * r0 as u64;

        // Partial carry propagation back into 26-bit limbs.
        let mut c = d0 >> 26;
        h0 = d0 as u32 & MASK26;
        d1 += c;
        c = d1 >> 26;
        h1 = d1 as u32 & MASK26;
        d2 += c;
        c = d2 >> 26;
        h2 = d2 as u32 & MASK26;
        d3 += c;
        c = d3 >> 26;
        h3 = d3 as u32 & MASK26;
        d4 += c;
        c = d4 >> 26;
        h4 = d4 as u32 & MASK26;
        h0 += (c * 5) as u32; // c < 2^30, so c*5 fits in u32
        c = (h0 >> 26) as u64;
        h0 &= MASK26;
        h1 += c as u32;
    }

    // Fully propagate the remaining carries.
    let mut c = h1 >> 26;
    h1 &= MASK26;
    h2 += c;
    c = h2 >> 26;
    h2 &= MASK26;
    h3 += c;
    c = h3 >> 26;
    h3 &= MASK26;
    h4 += c;
    c = h4 >> 26;
    h4 &= MASK26;
    h0 += c * 5;
    c = h0 >> 26;
    h0 &= MASK26;
    h1 += c;

    // Freeze: compute g = h - p and select g iff h >= p, branch-free.
    let mut g0 = h0.wrapping_add(5);
    c = g0 >> 26;
    g0 &= MASK26;
    let mut g1 = h1.wrapping_add(c);
    c = g1 >> 26;
    g1 &= MASK26;
    let mut g2 = h2.wrapping_add(c);
    c = g2 >> 26;
    g2 &= MASK26;
    let mut g3 = h3.wrapping_add(c);
    c = g3 >> 26;
    g3 &= MASK26;
    let g4 = h4.wrapping_add(c).wrapping_sub(1 << 26);

    // mask = all-ones if g4 did not underflow (h >= p), else zero.
    let mask = (g4 >> 31).wrapping_sub(1);
    h0 = (h0 & !mask) | (g0 & mask);
    h1 = (h1 & !mask) | (g1 & mask);
    h2 = (h2 & !mask) | (g2 & mask);
    h3 = (h3 & !mask) | (g3 & mask);
    h4 = (h4 & !mask) | (g4 & mask);

    // Repack 5x26-bit limbs into 4x32-bit words (h mod 2^128).
    let w0 = h0 | (h1 << 26);
    let w1 = (h1 >> 6) | (h2 << 20);
    let w2 = (h2 >> 12) | (h3 << 14);
    let w3 = (h3 >> 18) | (h4 << 8);

    // tag = (h + s) mod 2^128, where s = key[16..32].
    let mut f = w0 as u64 + le32(&key[16..20]) as u64;
    let t0 = f as u32;
    f = w1 as u64 + le32(&key[20..24]) as u64 + (f >> 32);
    let t1 = f as u32;
    f = w2 as u64 + le32(&key[24..28]) as u64 + (f >> 32);
    let t2 = f as u32;
    f = w3 as u64 + le32(&key[28..32]) as u64 + (f >> 32);
    let t3 = f as u32;

    let mut tag = [0u8; 16];
    tag[0..4].copy_from_slice(&t0.to_le_bytes());
    tag[4..8].copy_from_slice(&t1.to_le_bytes());
    tag[8..12].copy_from_slice(&t2.to_le_bytes());
    tag[12..16].copy_from_slice(&t3.to_le_bytes());
    tag
}
646
/// Multiply two 130-bit numbers mod 2^130-5
///
/// Uses the property that 2^130 = 5 (mod p) for reduction.
///
/// NOTE(review): this routine is NOT a sound modular multiply and was the
/// core reason the u128-based Poly1305 produced wrong tags:
/// - a u128 cannot even represent 130-bit operands or results;
/// - `mid << 64` discards the entire upper half of `mid`, which is never
///   added back, and the `carry` detection re-evaluates the same truncated
///   shift, so it only observes u128 overflow, not the lost bits;
/// - placing the "bits 128..129" correction at shift 64 puts them on top of
///   the middle of `result_lo` rather than at bit 128;
/// - the final "one more reduction pass" is explicitly left incomplete.
/// Kept only for reference; do not use for tag computation. A correct
/// implementation needs multi-limb arithmetic (e.g. radix-2^26).
fn poly1305_mulmod(a: u128, b: u128, _p: u128) -> u128 {
    // Split into 64-bit halves for multiplication
    let a_lo = a & 0xFFFF_FFFF_FFFF_FFFF;
    let a_hi = a >> 64;
    let b_lo = b & 0xFFFF_FFFF_FFFF_FFFF;
    let b_hi = b >> 64;

    // Four cross products of the halves
    let lo_lo = a_lo.wrapping_mul(b_lo);
    let lo_hi = a_lo.wrapping_mul(b_hi);
    let hi_lo = a_hi.wrapping_mul(b_lo);
    let hi_hi = a_hi.wrapping_mul(b_hi);

    // Intended combination: lo_lo + (lo_hi + hi_lo) << 64 + hi_hi << 128,
    // reduced mod 2^130-5 (bits above 130 multiplied by 5).
    // NOTE(review): `mid` may wrap, and `mid << 64` loses mid's high half.
    let mid = lo_hi.wrapping_add(hi_lo);
    let result_lo = lo_lo.wrapping_add(mid << 64);
    let carry = if lo_lo.checked_add(mid << 64).is_none() {
        1u128
    } else {
        0u128
    };

    let result_hi = hi_hi.wrapping_add(mid >> 64).wrapping_add(carry);

    // NOTE(review): dead experiment kept by the original author — `result_hi
    // << 64` overflows u128 whenever result_hi >= 2^64, so `_combined` is
    // meaningless and unused.
    let _combined = result_lo.wrapping_add(result_hi << 64);

    // Extract low 130 bits: result_lo gives bits 0-127, result_hi bits 0-1
    // are meant to be bits 128-129.
    let low_130_lo = result_lo; // bits 0..127
    let low_130_hi = result_hi & 0x3; // bits 128..129 (2 bits from result_hi)
    // NOTE(review): shifting by 64 places these bits at positions 64-65,
    // not 128-129 — they corrupt the middle of the low word instead.
    let low_130 = low_130_lo.wrapping_add((low_130_hi) << 64);

    // High bits (130+) = result_hi >> 2, folded back with factor 5
    let high_bits = result_hi >> 2;
    let reduced = low_130.wrapping_add(high_bits.wrapping_mul(5));

    // NOTE(review): a second reduction pass would be required here; the
    // original left it out, so `reduced` may exceed 2^130 - 5.
    let _r_lo = reduced; // bits 0..127
    reduced
}
717
718/// ChaCha20-Poly1305 AEAD encrypt (RFC 8439 Section 2.8)
719pub(crate) fn chacha20_poly1305_encrypt(
720    key: &[u8; 32],
721    nonce: &[u8; 12],
722    aad: &[u8],
723    plaintext: &[u8],
724) -> Vec<u8> {
725    // Generate Poly1305 one-time key from block 0
726    let otk_block = chacha20_block(key, nonce, 0);
727    let mut poly_key = [0u8; 32];
728    poly_key.copy_from_slice(&otk_block[..32]);
729
730    // Encrypt plaintext starting from counter 1
731    let ciphertext = chacha20_crypt(key, nonce, 1, plaintext);
732
733    // Construct Poly1305 input: AAD || pad || ciphertext || pad || len(AAD) ||
734    // len(CT)
735    let mac_input = build_poly1305_input(aad, &ciphertext);
736    let tag = poly1305_mac(&poly_key, &mac_input);
737
738    // Output: ciphertext || tag
739    let mut output = ciphertext;
740    output.extend_from_slice(&tag);
741    output
742}
743
744/// ChaCha20-Poly1305 AEAD decrypt (RFC 8439 Section 2.8)
745pub(crate) fn chacha20_poly1305_decrypt(
746    key: &[u8; 32],
747    nonce: &[u8; 12],
748    aad: &[u8],
749    ciphertext_and_tag: &[u8],
750) -> Option<Vec<u8>> {
751    if ciphertext_and_tag.len() < AEAD_TAG_LEN {
752        return None;
753    }
754
755    let ct_len = ciphertext_and_tag.len() - AEAD_TAG_LEN;
756    let ciphertext = &ciphertext_and_tag[..ct_len];
757    let tag = &ciphertext_and_tag[ct_len..];
758
759    // Generate Poly1305 one-time key
760    let otk_block = chacha20_block(key, nonce, 0);
761    let mut poly_key = [0u8; 32];
762    poly_key.copy_from_slice(&otk_block[..32]);
763
764    // Verify tag
765    let mac_input = build_poly1305_input(aad, ciphertext);
766    let expected_tag = poly1305_mac(&poly_key, &mac_input);
767
768    if !constant_time_eq(tag, &expected_tag) {
769        return None;
770    }
771
772    // Decrypt
773    Some(chacha20_crypt(key, nonce, 1, ciphertext))
774}
775
/// Build the Poly1305 MAC input per RFC 8439 Section 2.8:
/// AAD || pad16 || ciphertext || pad16 || len(AAD) LE u64 || len(CT) LE u64
fn build_poly1305_input(aad: &[u8], ciphertext: &[u8]) -> Vec<u8> {
    // Zero-pad to the next 16-byte boundary (no-op when already aligned).
    fn pad16(buf: &mut Vec<u8>) {
        while buf.len() % 16 != 0 {
            buf.push(0);
        }
    }

    let mut mac_data = Vec::with_capacity(aad.len() + ciphertext.len() + 47);
    mac_data.extend_from_slice(aad);
    pad16(&mut mac_data);
    mac_data.extend_from_slice(ciphertext);
    pad16(&mut mac_data);
    mac_data.extend_from_slice(&(aad.len() as u64).to_le_bytes());
    mac_data.extend_from_slice(&(ciphertext.len() as u64).to_le_bytes());
    mac_data
}
790
791// ============================================================================
792// AES-128-GCM AEAD (NIST SP 800-38D)
793// ============================================================================
794
/// AES S-Box
const AES_SBOX: [u8; 256] = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];

/// AES round constants
const AES_RCON: [u8; 11] = [
    0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
];

/// AES-128 block cipher (10 rounds, FIPS-197).
///
/// The 16-byte state is held column-major — `state[4*col + row]` — which
/// matches the byte order of the input block.
struct Aes128 {
    round_keys: [[u8; 16]; 11],
}

impl Aes128 {
    /// Expand a 128-bit key into the 11 round keys.
    fn new(key: &[u8; 16]) -> Self {
        let mut round_keys = [[0u8; 16]; 11];
        Self::key_expansion(key, &mut round_keys);
        Self { round_keys }
    }

    /// FIPS-197 key schedule: 44 32-bit words, 4 per round key.
    fn key_expansion(key: &[u8; 16], round_keys: &mut [[u8; 16]; 11]) {
        let mut schedule = [0u8; 176]; // 44 words * 4 bytes
        schedule[..16].copy_from_slice(key);

        for word in 4..44 {
            let base = word * 4;
            // Previous word of the schedule.
            let mut t = [
                schedule[base - 4],
                schedule[base - 3],
                schedule[base - 2],
                schedule[base - 1],
            ];

            // Every 4th word: RotWord + SubWord + round constant.
            if word % 4 == 0 {
                t = [
                    AES_SBOX[t[1] as usize] ^ AES_RCON[word / 4],
                    AES_SBOX[t[2] as usize],
                    AES_SBOX[t[3] as usize],
                    AES_SBOX[t[0] as usize],
                ];
            }

            // w[word] = w[word - 4] ^ t
            for k in 0..4 {
                schedule[base + k] = schedule[base + k - 16] ^ t[k];
            }
        }

        for (i, rk) in round_keys.iter_mut().enumerate() {
            rk.copy_from_slice(&schedule[i * 16..(i + 1) * 16]);
        }
    }

    /// SubBytes: S-box substitution on every state byte.
    fn sub_bytes(state: &mut [u8; 16]) {
        for b in state.iter_mut() {
            *b = AES_SBOX[*b as usize];
        }
    }

    /// ShiftRows: row r of the column-major state rotates left by r columns.
    fn shift_rows(state: &mut [u8; 16]) {
        // Source index for each destination position.
        const SRC: [usize; 16] = [0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11];
        let old = *state;
        for (dst, &src) in state.iter_mut().zip(SRC.iter()) {
            *dst = old[src];
        }
    }

    /// Multiply in GF(2^8) modulo the AES polynomial x^8 + x^4 + x^3 + x + 1.
    #[inline]
    fn gf_mul(a: u8, b: u8) -> u8 {
        let mut acc = 0u8;
        let mut x = a;
        let mut y = b;
        for _ in 0..8 {
            if y & 1 != 0 {
                acc ^= x;
            }
            let overflow = x & 0x80;
            x <<= 1;
            if overflow != 0 {
                x ^= 0x1b; // reduce by the AES polynomial
            }
            y >>= 1;
        }
        acc
    }

    /// MixColumns: multiply each column by the fixed MDS matrix.
    fn mix_columns(state: &mut [u8; 16]) {
        for col in state.chunks_exact_mut(4) {
            let (s0, s1, s2, s3) = (col[0], col[1], col[2], col[3]);
            col[0] = Self::gf_mul(2, s0) ^ Self::gf_mul(3, s1) ^ s2 ^ s3;
            col[1] = s0 ^ Self::gf_mul(2, s1) ^ Self::gf_mul(3, s2) ^ s3;
            col[2] = s0 ^ s1 ^ Self::gf_mul(2, s2) ^ Self::gf_mul(3, s3);
            col[3] = Self::gf_mul(3, s0) ^ s1 ^ s2 ^ Self::gf_mul(2, s3);
        }
    }

    /// AddRoundKey: XOR the round key into the state.
    fn add_round_key(state: &mut [u8; 16], round_key: &[u8; 16]) {
        for (s, k) in state.iter_mut().zip(round_key.iter()) {
            *s ^= k;
        }
    }

    /// Encrypt a single 16-byte block.
    fn encrypt_block(&self, block: &[u8; 16]) -> [u8; 16] {
        let mut state = *block;
        Self::add_round_key(&mut state, &self.round_keys[0]);

        // Rounds 1..=9.
        for rk in &self.round_keys[1..10] {
            Self::sub_bytes(&mut state);
            Self::shift_rows(&mut state);
            Self::mix_columns(&mut state);
            Self::add_round_key(&mut state, rk);
        }

        // Final round omits MixColumns.
        Self::sub_bytes(&mut state);
        Self::shift_rows(&mut state);
        Self::add_round_key(&mut state, &self.round_keys[10]);

        state
    }
}
935
/// GCM GHASH multiplication in GF(2^128) (NIST SP 800-38D, sec. 6.3).
///
/// Bit-serial shift-and-add over the field polynomial
/// x^128 + x^7 + x^2 + x + 1, which in GCM's reflected bit layout is
/// represented by R = 0xE1 || 0^120. Bits of `x` are consumed
/// most-significant first within each byte.
fn ghash_multiply(x: &[u8; 16], h: &[u8; 16]) -> [u8; 16] {
    let mut acc = [0u8; 16];
    let mut v = *h;

    for &xb in x.iter() {
        for bit in (0..8).rev() {
            // Accumulate V whenever the current bit of X is set.
            if (xb >> bit) & 1 != 0 {
                for k in 0..16 {
                    acc[k] ^= v[k];
                }
            }

            // V = V >> 1, reducing by R when a bit falls off the end.
            let carry = v[15] & 1;
            for k in (1..16).rev() {
                v[k] = (v[k] >> 1) | (v[k - 1] << 7);
            }
            v[0] >>= 1;
            if carry == 1 {
                v[0] ^= 0xE1;
            }
        }
    }

    acc
}
964
965/// GHASH function for GCM
966fn ghash(h: &[u8; 16], aad: &[u8], ciphertext: &[u8]) -> [u8; 16] {
967    let mut tag = [0u8; 16];
968
969    // Process AAD
970    for chunk in aad.chunks(16) {
971        let mut block = [0u8; 16];
972        block[..chunk.len()].copy_from_slice(chunk);
973        for i in 0..16 {
974            tag[i] ^= block[i];
975        }
976        tag = ghash_multiply(&tag, h);
977    }
978
979    // Process ciphertext
980    for chunk in ciphertext.chunks(16) {
981        let mut block = [0u8; 16];
982        block[..chunk.len()].copy_from_slice(chunk);
983        for i in 0..16 {
984            tag[i] ^= block[i];
985        }
986        tag = ghash_multiply(&tag, h);
987    }
988
989    // Length block: len(A) || len(C) in bits, big-endian 64-bit
990    let mut len_block = [0u8; 16];
991    let aad_bits = (aad.len() as u64).wrapping_mul(8);
992    let ct_bits = (ciphertext.len() as u64).wrapping_mul(8);
993    len_block[..8].copy_from_slice(&aad_bits.to_be_bytes());
994    len_block[8..16].copy_from_slice(&ct_bits.to_be_bytes());
995    for i in 0..16 {
996        tag[i] ^= len_block[i];
997    }
998    tag = ghash_multiply(&tag, h);
999
1000    tag
1001}
1002
1003/// AES-128-GCM encrypt
1004pub(crate) fn aes128_gcm_encrypt(
1005    key: &[u8; 16],
1006    nonce: &[u8; 12],
1007    aad: &[u8],
1008    plaintext: &[u8],
1009) -> Vec<u8> {
1010    let cipher = Aes128::new(key);
1011
1012    // H = AES_K(0^128)
1013    let h = cipher.encrypt_block(&[0u8; 16]);
1014
1015    // J0 = nonce || 0x00000001 (for 96-bit nonce)
1016    let mut j0 = [0u8; 16];
1017    j0[..12].copy_from_slice(nonce);
1018    j0[15] = 1;
1019
1020    // Encrypt plaintext with counter starting at J0 + 1
1021    let mut ciphertext = Vec::with_capacity(plaintext.len());
1022    let mut counter = 2u32;
1023    for chunk in plaintext.chunks(16) {
1024        let mut cb = j0;
1025        cb[12..16].copy_from_slice(&counter.to_be_bytes());
1026        let keystream = cipher.encrypt_block(&cb);
1027        for (i, &b) in chunk.iter().enumerate() {
1028            ciphertext.push(b ^ keystream[i]);
1029        }
1030        counter = counter.wrapping_add(1);
1031    }
1032
1033    // Compute GHASH
1034    let ghash_val = ghash(&h, aad, &ciphertext);
1035
1036    // Tag = GHASH XOR AES_K(J0)
1037    let j0_encrypted = cipher.encrypt_block(&j0);
1038    let mut tag = [0u8; 16];
1039    for i in 0..16 {
1040        tag[i] = ghash_val[i] ^ j0_encrypted[i];
1041    }
1042
1043    ciphertext.extend_from_slice(&tag);
1044    ciphertext
1045}
1046
1047/// AES-128-GCM decrypt
1048pub(crate) fn aes128_gcm_decrypt(
1049    key: &[u8; 16],
1050    nonce: &[u8; 12],
1051    aad: &[u8],
1052    ciphertext_and_tag: &[u8],
1053) -> Option<Vec<u8>> {
1054    if ciphertext_and_tag.len() < AEAD_TAG_LEN {
1055        return None;
1056    }
1057
1058    let ct_len = ciphertext_and_tag.len() - AEAD_TAG_LEN;
1059    let ciphertext = &ciphertext_and_tag[..ct_len];
1060    let received_tag = &ciphertext_and_tag[ct_len..];
1061
1062    let cipher = Aes128::new(key);
1063    let h = cipher.encrypt_block(&[0u8; 16]);
1064
1065    let mut j0 = [0u8; 16];
1066    j0[..12].copy_from_slice(nonce);
1067    j0[15] = 1;
1068
1069    // Verify tag first
1070    let ghash_val = ghash(&h, aad, ciphertext);
1071    let j0_encrypted = cipher.encrypt_block(&j0);
1072    let mut expected_tag = [0u8; 16];
1073    for i in 0..16 {
1074        expected_tag[i] = ghash_val[i] ^ j0_encrypted[i];
1075    }
1076
1077    if !constant_time_eq(received_tag, &expected_tag) {
1078        return None;
1079    }
1080
1081    // Decrypt
1082    let mut plaintext = Vec::with_capacity(ct_len);
1083    let mut counter = 2u32;
1084    for chunk in ciphertext.chunks(16) {
1085        let mut cb = j0;
1086        cb[12..16].copy_from_slice(&counter.to_be_bytes());
1087        let keystream = cipher.encrypt_block(&cb);
1088        for (i, &b) in chunk.iter().enumerate() {
1089            plaintext.push(b ^ keystream[i]);
1090        }
1091        counter = counter.wrapping_add(1);
1092    }
1093
1094    Some(plaintext)
1095}
1096
1097// ============================================================================
1098// AEAD Dispatch
1099// ============================================================================
1100
1101/// AEAD encrypt dispatcher for the negotiated cipher suite
1102pub(crate) fn aead_encrypt(
1103    cipher: CipherSuite,
1104    key: &[u8],
1105    nonce: &[u8; NONCE_LEN],
1106    aad: &[u8],
1107    plaintext: &[u8],
1108) -> Option<Vec<u8>> {
1109    match cipher {
1110        CipherSuite::ChaCha20Poly1305Sha256 => {
1111            if key.len() != CHACHA20_KEY_LEN {
1112                return None;
1113            }
1114            let mut k = [0u8; 32];
1115            k.copy_from_slice(key);
1116            Some(chacha20_poly1305_encrypt(&k, nonce, aad, plaintext))
1117        }
1118        CipherSuite::Aes128GcmSha256 => {
1119            if key.len() != AES_128_KEY_LEN {
1120                return None;
1121            }
1122            let mut k = [0u8; 16];
1123            k.copy_from_slice(key);
1124            Some(aes128_gcm_encrypt(&k, nonce, aad, plaintext))
1125        }
1126    }
1127}
1128
1129/// AEAD decrypt dispatcher for the negotiated cipher suite
1130pub(crate) fn aead_decrypt(
1131    cipher: CipherSuite,
1132    key: &[u8],
1133    nonce: &[u8; NONCE_LEN],
1134    aad: &[u8],
1135    ciphertext_and_tag: &[u8],
1136) -> Option<Vec<u8>> {
1137    match cipher {
1138        CipherSuite::ChaCha20Poly1305Sha256 => {
1139            if key.len() != CHACHA20_KEY_LEN {
1140                return None;
1141            }
1142            let mut k = [0u8; 32];
1143            k.copy_from_slice(key);
1144            chacha20_poly1305_decrypt(&k, nonce, aad, ciphertext_and_tag)
1145        }
1146        CipherSuite::Aes128GcmSha256 => {
1147            if key.len() != AES_128_KEY_LEN {
1148                return None;
1149            }
1150            let mut k = [0u8; 16];
1151            k.copy_from_slice(key);
1152            aes128_gcm_decrypt(&k, nonce, aad, ciphertext_and_tag)
1153        }
1154    }
1155}
1156
/// Constant-time comparison of two byte slices.
///
/// The length check returns early (lengths are public); the byte
/// comparison accumulates every difference so timing does not reveal the
/// position of the first mismatch.
pub(crate) fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y))
        == 0
}