1#![allow(dead_code)]
8
9use alloc::vec::Vec;
10
11use super::{CryptoError, CryptoResult};
12
13pub(crate) trait SymmetricCipher {
15 fn encrypt(&self, plaintext: &[u8], nonce: &[u8]) -> CryptoResult<Vec<u8>>;
17
18 fn decrypt(&self, ciphertext: &[u8], nonce: &[u8]) -> CryptoResult<Vec<u8>>;
20
21 fn key_size(&self) -> usize;
23
24 fn nonce_size(&self) -> usize;
26
27 fn tag_size(&self) -> usize;
29}
30
/// AES forward substitution box (FIPS-197 §5.1.1), indexed by the input byte.
///
/// Used by `SubBytes` and by the key schedule's `SubWord`. Because it is indexed
/// with secret-dependent bytes, the memory access pattern depends on key/state
/// data; table-based AES is therefore not constant-time with respect to caches.
const AES_SBOX: [u8; 256] = [
    // Row i holds S(0xi0)..=S(0xiF).
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];
50
/// AES key-schedule round constants: `AES_RCON[i]` is the leading byte of Rcon[i],
/// i.e. x^(i-1) in GF(2^8). Index 0 is a placeholder so the table can be indexed
/// directly by `i / Nk`; AES-256 key expansion reads only indices 1..=7.
const AES_RCON: [u8; 11] = [
    0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
];
55
/// Expanded AES-256 key schedule.
struct Aes256 {
    /// Fifteen 16-byte round keys: entry 0 is XORed in before the first round,
    /// entries 1..=14 after each of the 14 rounds.
    round_keys: [[u8; 16]; 15],
}
60
61impl Aes256 {
62 fn new(key: &[u8; 32]) -> Self {
64 let mut round_keys = [[0u8; 16]; 15];
65 Self::key_expansion(key, &mut round_keys);
66 Self { round_keys }
67 }
68
69 fn key_expansion(key: &[u8; 32], round_keys: &mut [[u8; 16]; 15]) {
71 let mut w = [0u8; 240]; w[..32].copy_from_slice(key);
75
76 for i in 8..60 {
78 let mut temp = [w[i * 4 - 4], w[i * 4 - 3], w[i * 4 - 2], w[i * 4 - 1]];
79
80 if i % 8 == 0 {
81 temp = [
83 AES_SBOX[temp[1] as usize] ^ AES_RCON[i / 8],
84 AES_SBOX[temp[2] as usize],
85 AES_SBOX[temp[3] as usize],
86 AES_SBOX[temp[0] as usize],
87 ];
88 } else if i % 8 == 4 {
89 temp = [
91 AES_SBOX[temp[0] as usize],
92 AES_SBOX[temp[1] as usize],
93 AES_SBOX[temp[2] as usize],
94 AES_SBOX[temp[3] as usize],
95 ];
96 }
97
98 w[i * 4] = w[i * 4 - 32] ^ temp[0];
99 w[i * 4 + 1] = w[i * 4 - 31] ^ temp[1];
100 w[i * 4 + 2] = w[i * 4 - 30] ^ temp[2];
101 w[i * 4 + 3] = w[i * 4 - 29] ^ temp[3];
102 }
103
104 for (i, rk) in round_keys.iter_mut().enumerate() {
106 rk.copy_from_slice(&w[i * 16..(i + 1) * 16]);
107 }
108 }
109
110 #[inline]
112 fn sub_bytes(state: &mut [u8; 16]) {
113 for byte in state.iter_mut() {
114 *byte = AES_SBOX[*byte as usize];
115 }
116 }
117
118 #[inline]
120 fn shift_rows(state: &mut [u8; 16]) {
121 let temp = *state;
122 state[1] = temp[5];
125 state[5] = temp[9];
126 state[9] = temp[13];
127 state[13] = temp[1];
128 state[2] = temp[10];
130 state[6] = temp[14];
131 state[10] = temp[2];
132 state[14] = temp[6];
133 state[3] = temp[15];
135 state[7] = temp[3];
136 state[11] = temp[7];
137 state[15] = temp[11];
138 }
139
140 #[inline]
142 fn gf_mul(a: u8, b: u8) -> u8 {
143 let mut result = 0u8;
144 let mut aa = a;
145 let mut bb = b;
146 for _ in 0..8 {
147 if bb & 1 != 0 {
148 result ^= aa;
149 }
150 let hi_bit = aa & 0x80;
151 aa <<= 1;
152 if hi_bit != 0 {
153 aa ^= 0x1b; }
155 bb >>= 1;
156 }
157 result
158 }
159
160 #[inline]
162 fn mix_columns(state: &mut [u8; 16]) {
163 for col in 0..4 {
164 let i = col * 4;
165 let s0 = state[i];
166 let s1 = state[i + 1];
167 let s2 = state[i + 2];
168 let s3 = state[i + 3];
169
170 state[i] = Self::gf_mul(2, s0) ^ Self::gf_mul(3, s1) ^ s2 ^ s3;
171 state[i + 1] = s0 ^ Self::gf_mul(2, s1) ^ Self::gf_mul(3, s2) ^ s3;
172 state[i + 2] = s0 ^ s1 ^ Self::gf_mul(2, s2) ^ Self::gf_mul(3, s3);
173 state[i + 3] = Self::gf_mul(3, s0) ^ s1 ^ s2 ^ Self::gf_mul(2, s3);
174 }
175 }
176
177 #[inline]
179 fn add_round_key(state: &mut [u8; 16], round_key: &[u8; 16]) {
180 for (s, k) in state.iter_mut().zip(round_key.iter()) {
181 *s ^= k;
182 }
183 }
184
185 fn encrypt_block(&self, block: &[u8; 16]) -> [u8; 16] {
187 let mut state = *block;
188
189 Self::add_round_key(&mut state, &self.round_keys[0]);
191
192 for round in 1..14 {
194 Self::sub_bytes(&mut state);
195 Self::shift_rows(&mut state);
196 Self::mix_columns(&mut state);
197 Self::add_round_key(&mut state, &self.round_keys[round]);
198 }
199
200 Self::sub_bytes(&mut state);
202 Self::shift_rows(&mut state);
203 Self::add_round_key(&mut state, &self.round_keys[14]);
204
205 state
206 }
207}
208
209pub(crate) struct Aes256Gcm {
211 aes: Aes256,
212 h: [u8; 16], }
214
215impl Aes256Gcm {
216 pub(crate) fn new(key: &[u8]) -> CryptoResult<Self> {
218 if key.len() != 32 {
219 return Err(CryptoError::InvalidKeySize);
220 }
221
222 let mut key_array = [0u8; 32];
223 key_array.copy_from_slice(key);
224
225 let aes = Aes256::new(&key_array);
226
227 let h = aes.encrypt_block(&[0u8; 16]);
229
230 Ok(Self { aes, h })
231 }
232
233 fn gcm_mult(&self, x: &[u8; 16]) -> [u8; 16] {
235 let mut z = [0u8; 16];
236 let mut v = self.h;
237
238 for i in 0..128 {
239 let bit = (x[i / 8] >> (7 - (i % 8))) & 1;
240 if bit == 1 {
241 for j in 0..16 {
242 z[j] ^= v[j];
243 }
244 }
245
246 let lsb = v[15] & 1;
247 for j in (1..16).rev() {
249 v[j] = (v[j] >> 1) | (v[j - 1] << 7);
250 }
251 v[0] >>= 1;
252
253 if lsb == 1 {
254 v[0] ^= 0xe1; }
256 }
257 z
258 }
259
260 fn ghash(&self, aad: &[u8], ciphertext: &[u8]) -> [u8; 16] {
262 let mut y = [0u8; 16];
263
264 for chunk in aad.chunks(16) {
266 let mut block = [0u8; 16];
267 block[..chunk.len()].copy_from_slice(chunk);
268 for i in 0..16 {
269 y[i] ^= block[i];
270 }
271 y = self.gcm_mult(&y);
272 }
273
274 for chunk in ciphertext.chunks(16) {
276 let mut block = [0u8; 16];
277 block[..chunk.len()].copy_from_slice(chunk);
278 for i in 0..16 {
279 y[i] ^= block[i];
280 }
281 y = self.gcm_mult(&y);
282 }
283
284 let aad_len_bits = (aad.len() as u64) * 8;
286 let ct_len_bits = (ciphertext.len() as u64) * 8;
287 let mut len_block = [0u8; 16];
288 len_block[..8].copy_from_slice(&aad_len_bits.to_be_bytes());
289 len_block[8..].copy_from_slice(&ct_len_bits.to_be_bytes());
290
291 for i in 0..16 {
292 y[i] ^= len_block[i];
293 }
294 self.gcm_mult(&y)
295 }
296
297 fn counter_block(nonce: &[u8], counter: u32) -> [u8; 16] {
299 let mut block = [0u8; 16];
300 block[..12].copy_from_slice(nonce);
301 block[12..16].copy_from_slice(&counter.to_be_bytes());
302 block
303 }
304}
305
306impl SymmetricCipher for Aes256Gcm {
307 fn encrypt(&self, plaintext: &[u8], nonce: &[u8]) -> CryptoResult<Vec<u8>> {
308 if nonce.len() != 12 {
309 return Err(CryptoError::InvalidNonceSize);
310 }
311
312 let mut ciphertext = Vec::with_capacity(plaintext.len() + 16);
313
314 let mut counter = 2u32; for chunk in plaintext.chunks(16) {
317 let counter_block = Self::counter_block(nonce, counter);
318 let keystream = self.aes.encrypt_block(&counter_block);
319
320 for (i, &p) in chunk.iter().enumerate() {
321 ciphertext.push(p ^ keystream[i]);
322 }
323 counter += 1;
324 }
325
326 let s = self.ghash(&[], &ciphertext);
328 let j0 = Self::counter_block(nonce, 1);
329 let encrypted_j0 = self.aes.encrypt_block(&j0);
330
331 for i in 0..16 {
332 ciphertext.push(s[i] ^ encrypted_j0[i]);
333 }
334
335 Ok(ciphertext)
336 }
337
338 fn decrypt(&self, ciphertext: &[u8], nonce: &[u8]) -> CryptoResult<Vec<u8>> {
339 if nonce.len() != 12 {
340 return Err(CryptoError::InvalidNonceSize);
341 }
342
343 if ciphertext.len() < 16 {
344 return Err(CryptoError::DecryptionFailed);
345 }
346
347 let data_len = ciphertext.len() - 16;
348 let (ct, tag) = ciphertext.split_at(data_len);
349
350 let s = self.ghash(&[], ct);
352 let j0 = Self::counter_block(nonce, 1);
353 let encrypted_j0 = self.aes.encrypt_block(&j0);
354
355 let mut expected_tag = [0u8; 16];
356 for i in 0..16 {
357 expected_tag[i] = s[i] ^ encrypted_j0[i];
358 }
359
360 let mut diff = 0u8;
362 for i in 0..16 {
363 diff |= tag[i] ^ expected_tag[i];
364 }
365 if diff != 0 {
366 return Err(CryptoError::DecryptionFailed);
367 }
368
369 let mut plaintext = Vec::with_capacity(data_len);
371 let mut counter = 2u32;
372 for chunk in ct.chunks(16) {
373 let counter_block = Self::counter_block(nonce, counter);
374 let keystream = self.aes.encrypt_block(&counter_block);
375
376 for (i, &c) in chunk.iter().enumerate() {
377 plaintext.push(c ^ keystream[i]);
378 }
379 counter += 1;
380 }
381
382 Ok(plaintext)
383 }
384
385 fn key_size(&self) -> usize {
386 32
387 }
388
389 fn nonce_size(&self) -> usize {
390 12
391 }
392
393 fn tag_size(&self) -> usize {
394 16
395 }
396}
397
/// ChaCha20-Poly1305 AEAD with a 96-bit nonce and a 128-bit tag.
pub(crate) struct ChaCha20Poly1305 {
    /// 256-bit ChaCha20 key; Poly1305 one-time keys are derived from it per nonce.
    key: [u8; 32],
}
402
403impl ChaCha20Poly1305 {
404 pub(crate) fn new(key: &[u8]) -> CryptoResult<Self> {
406 if key.len() != 32 {
407 return Err(CryptoError::InvalidKeySize);
408 }
409
410 let mut key_array = [0u8; 32];
411 key_array.copy_from_slice(key);
412
413 Ok(Self { key: key_array })
414 }
415
416 #[inline]
418 fn quarter_round(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
419 state[a] = state[a].wrapping_add(state[b]);
420 state[d] ^= state[a];
421 state[d] = state[d].rotate_left(16);
422
423 state[c] = state[c].wrapping_add(state[d]);
424 state[b] ^= state[c];
425 state[b] = state[b].rotate_left(12);
426
427 state[a] = state[a].wrapping_add(state[b]);
428 state[d] ^= state[a];
429 state[d] = state[d].rotate_left(8);
430
431 state[c] = state[c].wrapping_add(state[d]);
432 state[b] ^= state[c];
433 state[b] = state[b].rotate_left(7);
434 }
435
436 fn chacha20_block(&self, nonce: &[u8], counter: u32) -> [u8; 64] {
438 let mut state: [u32; 16] = [
440 0x61707865,
441 0x3320646e,
442 0x79622d32,
443 0x6b206574, u32::from_le_bytes([self.key[0], self.key[1], self.key[2], self.key[3]]),
445 u32::from_le_bytes([self.key[4], self.key[5], self.key[6], self.key[7]]),
446 u32::from_le_bytes([self.key[8], self.key[9], self.key[10], self.key[11]]),
447 u32::from_le_bytes([self.key[12], self.key[13], self.key[14], self.key[15]]),
448 u32::from_le_bytes([self.key[16], self.key[17], self.key[18], self.key[19]]),
449 u32::from_le_bytes([self.key[20], self.key[21], self.key[22], self.key[23]]),
450 u32::from_le_bytes([self.key[24], self.key[25], self.key[26], self.key[27]]),
451 u32::from_le_bytes([self.key[28], self.key[29], self.key[30], self.key[31]]),
452 counter,
453 u32::from_le_bytes([nonce[0], nonce[1], nonce[2], nonce[3]]),
454 u32::from_le_bytes([nonce[4], nonce[5], nonce[6], nonce[7]]),
455 u32::from_le_bytes([nonce[8], nonce[9], nonce[10], nonce[11]]),
456 ];
457
458 let initial_state = state;
459
460 for _ in 0..10 {
462 Self::quarter_round(&mut state, 0, 4, 8, 12);
464 Self::quarter_round(&mut state, 1, 5, 9, 13);
465 Self::quarter_round(&mut state, 2, 6, 10, 14);
466 Self::quarter_round(&mut state, 3, 7, 11, 15);
467 Self::quarter_round(&mut state, 0, 5, 10, 15);
469 Self::quarter_round(&mut state, 1, 6, 11, 12);
470 Self::quarter_round(&mut state, 2, 7, 8, 13);
471 Self::quarter_round(&mut state, 3, 4, 9, 14);
472 }
473
474 for i in 0..16 {
476 state[i] = state[i].wrapping_add(initial_state[i]);
477 }
478
479 let mut output = [0u8; 64];
481 for (i, &word) in state.iter().enumerate() {
482 output[i * 4..(i + 1) * 4].copy_from_slice(&word.to_le_bytes());
483 }
484 output
485 }
486
487 fn poly1305_mac(&self, key: &[u8; 32], message: &[u8]) -> [u8; 16] {
493 let mut r_bytes = [0u8; 16];
495 r_bytes.copy_from_slice(&key[..16]);
496 r_bytes[3] &= 15;
497 r_bytes[7] &= 15;
498 r_bytes[11] &= 15;
499 r_bytes[15] &= 15;
500 r_bytes[4] &= 252;
501 r_bytes[8] &= 252;
502 r_bytes[12] &= 252;
503
504 let r0 = (u32::from_le_bytes([r_bytes[0], r_bytes[1], r_bytes[2], r_bytes[3]])) & 0x3ffffff;
506 let r1 =
507 (u32::from_le_bytes([r_bytes[3], r_bytes[4], r_bytes[5], r_bytes[6]]) >> 2) & 0x3ffffff;
508 let r2 =
509 (u32::from_le_bytes([r_bytes[6], r_bytes[7], r_bytes[8], r_bytes[9]]) >> 4) & 0x3ffffff;
510 let r3 = (u32::from_le_bytes([r_bytes[9], r_bytes[10], r_bytes[11], r_bytes[12]]) >> 6)
511 & 0x3ffffff;
512 let r4 = (u32::from_le_bytes([r_bytes[12], r_bytes[13], r_bytes[14], r_bytes[15]]) >> 8)
513 & 0x3ffffff;
514
515 let s1 = r1.wrapping_mul(5);
518 let s2 = r2.wrapping_mul(5);
519 let s3 = r3.wrapping_mul(5);
520 let s4 = r4.wrapping_mul(5);
521
522 let mut h0: u32 = 0;
524 let mut h1: u32 = 0;
525 let mut h2: u32 = 0;
526 let mut h3: u32 = 0;
527 let mut h4: u32 = 0;
528
529 for chunk in message.chunks(16) {
531 let mut block = [0u8; 17];
533 block[..chunk.len()].copy_from_slice(chunk);
534 block[chunk.len()] = 1; let t0 = u32::from_le_bytes([block[0], block[1], block[2], block[3]]);
538 let t1 = u32::from_le_bytes([block[4], block[5], block[6], block[7]]);
539 let t2 = u32::from_le_bytes([block[8], block[9], block[10], block[11]]);
540 let t3 = u32::from_le_bytes([block[12], block[13], block[14], block[15]]);
541
542 h0 = h0.wrapping_add(t0 & 0x3ffffff);
543 h1 = h1.wrapping_add(((t0 >> 26) | (t1 << 6)) & 0x3ffffff);
544 h2 = h2.wrapping_add(((t1 >> 20) | (t2 << 12)) & 0x3ffffff);
545 h3 = h3.wrapping_add(((t2 >> 14) | (t3 << 18)) & 0x3ffffff);
546 h4 = h4.wrapping_add((t3 >> 8) | ((block[16] as u32) << 24));
547
548 let d0 = (h0 as u64)
551 .wrapping_mul(r0 as u64)
552 .wrapping_add((h1 as u64).wrapping_mul(s4 as u64))
553 .wrapping_add((h2 as u64).wrapping_mul(s3 as u64))
554 .wrapping_add((h3 as u64).wrapping_mul(s2 as u64))
555 .wrapping_add((h4 as u64).wrapping_mul(s1 as u64));
556 let d1 = (h0 as u64)
557 .wrapping_mul(r1 as u64)
558 .wrapping_add((h1 as u64).wrapping_mul(r0 as u64))
559 .wrapping_add((h2 as u64).wrapping_mul(s4 as u64))
560 .wrapping_add((h3 as u64).wrapping_mul(s3 as u64))
561 .wrapping_add((h4 as u64).wrapping_mul(s2 as u64));
562 let d2 = (h0 as u64)
563 .wrapping_mul(r2 as u64)
564 .wrapping_add((h1 as u64).wrapping_mul(r1 as u64))
565 .wrapping_add((h2 as u64).wrapping_mul(r0 as u64))
566 .wrapping_add((h3 as u64).wrapping_mul(s4 as u64))
567 .wrapping_add((h4 as u64).wrapping_mul(s3 as u64));
568 let d3 = (h0 as u64)
569 .wrapping_mul(r3 as u64)
570 .wrapping_add((h1 as u64).wrapping_mul(r2 as u64))
571 .wrapping_add((h2 as u64).wrapping_mul(r1 as u64))
572 .wrapping_add((h3 as u64).wrapping_mul(r0 as u64))
573 .wrapping_add((h4 as u64).wrapping_mul(s4 as u64));
574 let d4 = (h0 as u64)
575 .wrapping_mul(r4 as u64)
576 .wrapping_add((h1 as u64).wrapping_mul(r3 as u64))
577 .wrapping_add((h2 as u64).wrapping_mul(r2 as u64))
578 .wrapping_add((h3 as u64).wrapping_mul(r1 as u64))
579 .wrapping_add((h4 as u64).wrapping_mul(r0 as u64));
580
581 let mut c: u32;
583 c = (d0 >> 26) as u32;
584 h0 = d0 as u32 & 0x3ffffff;
585 let d1 = d1.wrapping_add(c as u64);
586 c = (d1 >> 26) as u32;
587 h1 = d1 as u32 & 0x3ffffff;
588 let d2 = d2.wrapping_add(c as u64);
589 c = (d2 >> 26) as u32;
590 h2 = d2 as u32 & 0x3ffffff;
591 let d3 = d3.wrapping_add(c as u64);
592 c = (d3 >> 26) as u32;
593 h3 = d3 as u32 & 0x3ffffff;
594 let d4 = d4.wrapping_add(c as u64);
595 c = (d4 >> 26) as u32;
596 h4 = d4 as u32 & 0x3ffffff;
597 h0 = h0.wrapping_add(c.wrapping_mul(5));
599 c = h0 >> 26;
600 h0 &= 0x3ffffff;
601 h1 = h1.wrapping_add(c);
602 }
603
604 let mut c: u32;
606 c = h1 >> 26;
607 h1 &= 0x3ffffff;
608 h2 = h2.wrapping_add(c);
609 c = h2 >> 26;
610 h2 &= 0x3ffffff;
611 h3 = h3.wrapping_add(c);
612 c = h3 >> 26;
613 h3 &= 0x3ffffff;
614 h4 = h4.wrapping_add(c);
615 c = h4 >> 26;
616 h4 &= 0x3ffffff;
617 h0 = h0.wrapping_add(c.wrapping_mul(5));
618 c = h0 >> 26;
619 h0 &= 0x3ffffff;
620 h1 = h1.wrapping_add(c);
621
622 let mut g0 = h0.wrapping_add(5);
624 c = g0 >> 26;
625 g0 &= 0x3ffffff;
626 let mut g1 = h1.wrapping_add(c);
627 c = g1 >> 26;
628 g1 &= 0x3ffffff;
629 let mut g2 = h2.wrapping_add(c);
630 c = g2 >> 26;
631 g2 &= 0x3ffffff;
632 let mut g3 = h3.wrapping_add(c);
633 c = g3 >> 26;
634 g3 &= 0x3ffffff;
635 let g4 = h4.wrapping_add(c).wrapping_sub(1 << 26);
636
637 let mask = (g4 >> 31).wrapping_sub(1); g0 &= mask;
641 g1 &= mask;
642 g2 &= mask;
643 g3 &= mask;
644 let nmask = !mask;
645 h0 = (h0 & nmask) | g0;
646 h1 = (h1 & nmask) | g1;
647 h2 = (h2 & nmask) | g2;
648 h3 = (h3 & nmask) | g3;
649
650 let s_bytes = &key[16..32];
652 let s0 = u64::from_le_bytes([
653 s_bytes[0], s_bytes[1], s_bytes[2], s_bytes[3], s_bytes[4], s_bytes[5], s_bytes[6],
654 s_bytes[7],
655 ]);
656 let s1_val = u64::from_le_bytes([
657 s_bytes[8],
658 s_bytes[9],
659 s_bytes[10],
660 s_bytes[11],
661 s_bytes[12],
662 s_bytes[13],
663 s_bytes[14],
664 s_bytes[15],
665 ]);
666
667 let mut f: u64;
668 f = (h0 as u64) | ((h1 as u64) << 26) | ((h2 as u64) << 52);
669 let low = f;
670 f = ((h2 as u64) >> 12) | ((h3 as u64) << 14) | ((h4 as u64) << 40);
671 let high = f;
672
673 let (result0, carry_bit) = low.overflowing_add(s0);
674 let result1 = high.wrapping_add(s1_val).wrapping_add(carry_bit as u64);
675
676 let mut tag = [0u8; 16];
677 tag[..8].copy_from_slice(&result0.to_le_bytes());
678 tag[8..].copy_from_slice(&result1.to_le_bytes());
679 tag
680 }
681}
682
683impl SymmetricCipher for ChaCha20Poly1305 {
684 fn encrypt(&self, plaintext: &[u8], nonce: &[u8]) -> CryptoResult<Vec<u8>> {
685 if nonce.len() != 12 {
686 return Err(CryptoError::InvalidNonceSize);
687 }
688
689 let poly_key_block = self.chacha20_block(nonce, 0);
691 let mut poly_key = [0u8; 32];
692 poly_key.copy_from_slice(&poly_key_block[..32]);
693
694 let mut ciphertext = Vec::with_capacity(plaintext.len() + 16);
696 let mut counter = 1u32;
697
698 for chunk in plaintext.chunks(64) {
699 let keystream = self.chacha20_block(nonce, counter);
700 for (i, &p) in chunk.iter().enumerate() {
701 ciphertext.push(p ^ keystream[i]);
702 }
703 counter += 1;
704 }
705
706 let mut auth_data = Vec::new();
708 auth_data.extend_from_slice(&ciphertext);
710 while auth_data.len() % 16 != 0 {
712 auth_data.push(0);
713 }
714 auth_data.extend_from_slice(&0u64.to_le_bytes()); auth_data.extend_from_slice(&(ciphertext.len() as u64).to_le_bytes());
717
718 let tag = self.poly1305_mac(&poly_key, &auth_data);
720 ciphertext.extend_from_slice(&tag);
721
722 Ok(ciphertext)
723 }
724
725 fn decrypt(&self, ciphertext: &[u8], nonce: &[u8]) -> CryptoResult<Vec<u8>> {
726 if nonce.len() != 12 {
727 return Err(CryptoError::InvalidNonceSize);
728 }
729
730 if ciphertext.len() < 16 {
731 return Err(CryptoError::DecryptionFailed);
732 }
733
734 let data_len = ciphertext.len() - 16;
735 let (ct, tag) = ciphertext.split_at(data_len);
736
737 let poly_key_block = self.chacha20_block(nonce, 0);
739 let mut poly_key = [0u8; 32];
740 poly_key.copy_from_slice(&poly_key_block[..32]);
741
742 let mut auth_data = Vec::new();
744 auth_data.extend_from_slice(ct);
745 while auth_data.len() % 16 != 0 {
746 auth_data.push(0);
747 }
748 auth_data.extend_from_slice(&0u64.to_le_bytes());
749 auth_data.extend_from_slice(&(ct.len() as u64).to_le_bytes());
750
751 let expected_tag = self.poly1305_mac(&poly_key, &auth_data);
752
753 let mut diff = 0u8;
755 for i in 0..16 {
756 diff |= tag[i] ^ expected_tag[i];
757 }
758 if diff != 0 {
759 return Err(CryptoError::DecryptionFailed);
760 }
761
762 let mut plaintext = Vec::with_capacity(data_len);
764 let mut counter = 1u32;
765
766 for chunk in ct.chunks(64) {
767 let keystream = self.chacha20_block(nonce, counter);
768 for (i, &c) in chunk.iter().enumerate() {
769 plaintext.push(c ^ keystream[i]);
770 }
771 counter += 1;
772 }
773
774 Ok(plaintext)
775 }
776
777 fn key_size(&self) -> usize {
778 32
779 }
780
781 fn nonce_size(&self) -> usize {
782 12
783 }
784
785 fn tag_size(&self) -> usize {
786 16
787 }
788}
789
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aes256gcm_encrypt_decrypt() {
        let key = [0x42u8; 32];
        let cipher = Aes256Gcm::new(&key).unwrap();
        let nonce = [0x12u8; 12];
        let plaintext = b"Hello, VeridianOS!";

        let ciphertext = cipher.encrypt(plaintext, &nonce).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, &nonce).unwrap();

        assert_eq!(plaintext.as_ref(), decrypted.as_slice());
    }

    /// FIPS-197 Appendix C.3 (AES-256 single block).
    #[test]
    fn test_aes256_fips197_known_answer() {
        let mut key = [0u8; 32];
        for (i, byte) in key.iter_mut().enumerate() {
            *byte = i as u8;
        }
        let plaintext = [
            0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd,
            0xee, 0xff,
        ];
        let expected = [
            0x8e, 0xa2, 0xb7, 0xca, 0x51, 0x67, 0x45, 0xbf, 0xea, 0xfc, 0x49, 0x90, 0x4b, 0x49,
            0x60, 0x89,
        ];
        assert_eq!(Aes256::new(&key).encrypt_block(&plaintext), expected);
    }

    /// GCM specification (McGrew & Viega) test cases 13 and 14: all-zero key and IV.
    #[test]
    fn test_aes256gcm_known_answers() {
        let cipher = Aes256Gcm::new(&[0u8; 32]).unwrap();
        let nonce = [0u8; 12];

        let tc13_tag = [
            0x53, 0x0f, 0x8a, 0xfb, 0xc7, 0x45, 0x36, 0xb9, 0xa9, 0x63, 0xb4, 0xf1, 0xc4, 0xcb,
            0x73, 0x8b,
        ];
        let tag_only = cipher.encrypt(&[], &nonce).unwrap();
        assert_eq!(tag_only.as_slice(), &tc13_tag[..]);

        let tc14_ct = [
            0xce, 0xa7, 0x40, 0x3d, 0x4d, 0x60, 0x6b, 0x6e, 0x07, 0x4e, 0xc5, 0xd3, 0xba, 0xf3,
            0x9d, 0x18,
        ];
        let tc14_tag = [
            0xd0, 0xd1, 0xc8, 0xa7, 0x99, 0x99, 0x6b, 0xf0, 0x26, 0x5b, 0x98, 0xb5, 0xd4, 0x8a,
            0xb9, 0x19,
        ];
        let sealed = cipher.encrypt(&[0u8; 16], &nonce).unwrap();
        assert_eq!(&sealed[..16], &tc14_ct[..]);
        assert_eq!(&sealed[16..], &tc14_tag[..]);
        assert_eq!(cipher.decrypt(&sealed, &nonce).unwrap().as_slice(), &[0u8; 16][..]);
    }

    /// RFC 8439 §2.3.2 ChaCha20 block function (first 16 output bytes).
    #[test]
    fn test_chacha20_block_rfc8439() {
        let mut key = [0u8; 32];
        for (i, byte) in key.iter_mut().enumerate() {
            *byte = i as u8;
        }
        let cipher = ChaCha20Poly1305::new(&key).unwrap();
        let nonce = [0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00];
        let block = cipher.chacha20_block(&nonce, 1);
        let expected_prefix = [
            0x10, 0xf1, 0xe7, 0xe4, 0xd1, 0x3b, 0x59, 0x15, 0x50, 0x0f, 0xdd, 0x1f, 0xa3, 0x20,
            0x71, 0xc4,
        ];
        assert_eq!(&block[..16], &expected_prefix[..]);
    }

    /// RFC 8439 §2.5.2 Poly1305 test vector.
    #[test]
    fn test_poly1305_rfc8439() {
        let key: [u8; 32] = [
            0x85, 0xd6, 0xbe, 0x78, 0x57, 0x55, 0x6d, 0x33, 0x7f, 0x44, 0x52, 0xfe, 0x42, 0xd5,
            0x06, 0xa8, 0x01, 0x03, 0x80, 0x8a, 0xfb, 0x0d, 0xb2, 0xfd, 0x4a, 0xbf, 0xf6, 0xaf,
            0x41, 0x49, 0xf5, 0x1b,
        ];
        let expected = [
            0xa8, 0x06, 0x1d, 0xc1, 0x30, 0x51, 0x36, 0xc6, 0xc2, 0x2b, 0x8b, 0xaf, 0x0c, 0x01,
            0x27, 0xa9,
        ];
        let cipher = ChaCha20Poly1305::new(&[0u8; 32]).unwrap();
        let tag = cipher.poly1305_mac(&key, b"Cryptographic Forum Research Group");
        assert_eq!(tag, expected);
    }

    /// Round trip spanning more than one 64-byte ChaCha20 block.
    #[test]
    fn test_chacha20poly1305_encrypt_decrypt() {
        let cipher = ChaCha20Poly1305::new(&[0x24u8; 32]).unwrap();
        let nonce = [0x07u8; 12];
        let plaintext = [0x5au8; 100];

        let sealed = cipher.encrypt(&plaintext, &nonce).unwrap();
        assert_eq!(sealed.len(), plaintext.len() + cipher.tag_size());
        let opened = cipher.decrypt(&sealed, &nonce).unwrap();
        assert_eq!(opened.as_slice(), &plaintext[..]);
    }

    /// Any modification to ciphertext or tag must be rejected.
    #[test]
    fn test_tampering_is_rejected() {
        let nonce = [0x12u8; 12];

        let gcm = Aes256Gcm::new(&[0x42u8; 32]).unwrap();
        let mut sealed = gcm.encrypt(b"integrity", &nonce).unwrap();
        sealed[0] ^= 1;
        assert!(matches!(gcm.decrypt(&sealed, &nonce), Err(CryptoError::DecryptionFailed)));

        let chacha = ChaCha20Poly1305::new(&[0x42u8; 32]).unwrap();
        let mut sealed = chacha.encrypt(b"integrity", &nonce).unwrap();
        let last = sealed.len() - 1;
        sealed[last] ^= 0x80;
        assert!(matches!(chacha.decrypt(&sealed, &nonce), Err(CryptoError::DecryptionFailed)));
    }

    /// Wrong key, nonce and ciphertext sizes are reported, not panicked on.
    #[test]
    fn test_invalid_sizes_are_rejected() {
        assert!(matches!(Aes256Gcm::new(&[0u8; 16]), Err(CryptoError::InvalidKeySize)));
        assert!(matches!(ChaCha20Poly1305::new(&[0u8; 31]), Err(CryptoError::InvalidKeySize)));

        let gcm = Aes256Gcm::new(&[0u8; 32]).unwrap();
        assert!(matches!(gcm.encrypt(b"x", &[0u8; 8]), Err(CryptoError::InvalidNonceSize)));
        assert!(matches!(gcm.decrypt(&[0u8; 15], &[0u8; 12]), Err(CryptoError::DecryptionFailed)));

        let chacha = ChaCha20Poly1305::new(&[0u8; 32]).unwrap();
        assert!(matches!(chacha.encrypt(b"x", &[0u8; 13]), Err(CryptoError::InvalidNonceSize)));
        assert!(matches!(
            chacha.decrypt(&[0u8; 15], &[0u8; 12]),
            Err(CryptoError::DecryptionFailed)
        ));
    }
}