1#![allow(unexpected_cfgs)]
2#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11use alloc::vec::Vec;
12
/// Number of platform configuration registers (PCRs) held by `PcrState`.
const MAX_PCRS: usize = 24;

/// Size in bytes of every PCR value and of every digest extended into one.
const DIGEST_LEN: usize = 32;
18
/// Stages of the boot chain, in boot order.
///
/// Discriminants are contiguous from 0, so `stage as usize` is a valid index
/// into any array of length `BootStage::COUNT`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BootStage {
    Firmware = 0,
    Bootloader = 1,
    Kernel = 2,
    InitSystem = 3,
    DriverFramework = 4,
    UserSpace = 5,
}

impl BootStage {
    /// Number of stages. It is derived from the last variant so it stays in
    /// sync if stages are added at the end.
    const COUNT: usize = Self::UserSpace as usize + 1;

    /// Converts a zero-based index back to its stage.
    ///
    /// Returns `None` when `idx >= BootStage::COUNT`.
    fn from_index(idx: usize) -> Option<Self> {
        // Listed in discriminant order, so position == discriminant.
        const ORDER: [BootStage; BootStage::COUNT] = [
            BootStage::Firmware,
            BootStage::Bootloader,
            BootStage::Kernel,
            BootStage::InitSystem,
            BootStage::DriverFramework,
            BootStage::UserSpace,
        ];
        ORDER.get(idx).copied()
    }
}
53
/// Lifecycle of a boot-chain verification.
///
/// `BootChainVerifier::set_status` only permits
/// `NotStarted -> Measuring -> Verifying -> Approved | Rejected`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BootStatus {
    /// Initial (default) state.
    #[default]
    NotStarted,
    /// Measurements are being collected.
    Measuring,
    /// Collected measurements are being checked.
    Verifying,
    /// Verification succeeded; no further transitions are allowed.
    Approved,
    /// Verification failed; no further transitions are allowed.
    Rejected,
}
69
/// Outcome of evaluating the expected-PCR boot policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyDecision {
    /// Every expected PCR value matched.
    Allow,
    /// Warning-level outcome.
    /// NOTE(review): nothing in this file currently produces `Warn`.
    Warn,
    /// At least one expected PCR value did not match.
    Deny,
}
80
81#[derive(Debug, Clone)]
83pub struct PcrState {
84 values: [[u8; DIGEST_LEN]; MAX_PCRS],
86 extend_count: [u32; MAX_PCRS],
88}
89
90impl Default for PcrState {
91 fn default() -> Self {
92 Self {
93 values: [[0u8; DIGEST_LEN]; MAX_PCRS],
94 extend_count: [0u32; MAX_PCRS],
95 }
96 }
97}
98
99impl PcrState {
100 pub fn new() -> Self {
102 Self::default()
103 }
104
105 pub fn extend(
109 &mut self,
110 pcr_index: usize,
111 digest: &[u8; DIGEST_LEN],
112 ) -> Result<[u8; DIGEST_LEN], BootVerifyError> {
113 if pcr_index >= MAX_PCRS {
114 return Err(BootVerifyError::InvalidPcrIndex);
115 }
116
117 let mut concat = [0u8; DIGEST_LEN * 2];
119 concat[..DIGEST_LEN].copy_from_slice(&self.values[pcr_index]);
120 concat[DIGEST_LEN..].copy_from_slice(digest);
121
122 let new_value = simple_sha256_model(&concat);
125 self.values[pcr_index] = new_value;
126 self.extend_count[pcr_index] = self.extend_count[pcr_index].saturating_add(1);
127
128 Ok(new_value)
129 }
130
131 pub fn get(&self, pcr_index: usize) -> Option<&[u8; DIGEST_LEN]> {
133 if pcr_index < MAX_PCRS {
134 Some(&self.values[pcr_index])
135 } else {
136 None
137 }
138 }
139
140 pub fn get_extend_count(&self, pcr_index: usize) -> Option<u32> {
142 if pcr_index < MAX_PCRS {
143 Some(self.extend_count[pcr_index])
144 } else {
145 None
146 }
147 }
148
149 pub fn is_extended(&self, pcr_index: usize) -> bool {
151 pcr_index < MAX_PCRS && self.extend_count[pcr_index] > 0
152 }
153}
154
155#[derive(Debug, Clone)]
157pub struct MeasurementEntry {
158 pub pcr_index: usize,
160 pub digest: [u8; DIGEST_LEN],
162 pub sequence: u64,
164 pub description: MeasuredComponent,
166}
167
/// Identifies the component a measurement-log entry describes
/// (stored in `MeasurementEntry::description`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MeasuredComponent {
    FirmwareCode,
    BootloaderBinary,
    KernelImage,
    /// The kernel command line, measured separately from the image.
    KernelCmdline,
    InitBinary,
    DriverBinary,
    UserComponent,
}
186
/// Ordered record of measurements; the verifier replays it to re-derive PCRs.
///
/// With the `alloc` feature every measurement is stored. Without it only the
/// sequence counter exists and the log has no way to add entries.
#[derive(Debug, Clone, Default)]
pub struct MeasurementLog {
    /// Recorded entries, oldest first.
    #[cfg(feature = "alloc")]
    entries: Vec<MeasurementEntry>,
    /// NOTE(review): nothing in the visible code reads or updates this
    /// non-`alloc` placeholder; confirm whether it is still needed.
    #[cfg(not(feature = "alloc"))]
    entries_count: usize,
    /// Sequence number the next added entry will receive.
    next_sequence: u64,
}

impl MeasurementLog {
    /// Creates an empty log. The first entry added gets sequence number 0.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a measurement stamped with the next sequence number, then
    /// advances the counter.
    #[cfg(feature = "alloc")]
    pub fn add(
        &mut self,
        pcr_index: usize,
        digest: [u8; DIGEST_LEN],
        component: MeasuredComponent,
    ) {
        self.entries.push(MeasurementEntry {
            pcr_index,
            digest,
            sequence: self.next_sequence,
            description: component,
        });
        self.next_sequence += 1;
    }

    /// Number of recorded entries.
    #[cfg(feature = "alloc")]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entry has been recorded.
    #[cfg(feature = "alloc")]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// All recorded entries, oldest first.
    #[cfg(feature = "alloc")]
    pub fn entries(&self) -> &[MeasurementEntry] {
        self.entries.as_slice()
    }
}
241
/// Reasons a PCR operation or boot-chain check can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BootVerifyError {
    /// A PCR index was `>= MAX_PCRS`.
    InvalidPcrIndex,
    /// A PCR that has been extended holds the all-zero value.
    PcrReset,
    /// At least one boot stage was never measured.
    MissingMeasurement,
    /// Replaying the measurement log does not reproduce the live PCR values.
    HashChainBroken,
    /// An operation or status transition not permitted in the current
    /// `BootStatus`.
    PolicyViolation,
    /// Log entries are not in strictly increasing sequence order.
    LogOutOfOrder,
    /// Per-PCR entry counts in the log disagree with the PCR extend counts.
    CountMismatch,
}
260
261#[derive(Debug, Default)]
263pub struct BootChainVerifier {
264 pcr_state: PcrState,
266 #[cfg(feature = "alloc")]
268 log: MeasurementLog,
269 status: BootStatus,
271 stages_measured: [bool; BootStage::COUNT],
273 #[cfg(feature = "alloc")]
275 expected_pcrs: Vec<(usize, [u8; DIGEST_LEN])>,
276}
277
278impl BootChainVerifier {
279 pub fn new() -> Self {
281 Self::default()
282 }
283
284 #[cfg(feature = "alloc")]
286 pub fn measure(
287 &mut self,
288 stage: BootStage,
289 pcr_index: usize,
290 digest: [u8; DIGEST_LEN],
291 component: MeasuredComponent,
292 ) -> Result<(), BootVerifyError> {
293 if self.status == BootStatus::NotStarted {
294 self.status = BootStatus::Measuring;
295 }
296
297 self.pcr_state.extend(pcr_index, &digest)?;
299
300 self.log.add(pcr_index, digest, component);
302
303 self.stages_measured[stage as usize] = true;
305
306 Ok(())
307 }
308
309 #[cfg(feature = "alloc")]
311 pub fn set_expected_pcr(&mut self, pcr_index: usize, expected: [u8; DIGEST_LEN]) {
312 self.expected_pcrs.push((pcr_index, expected));
313 }
314
315 pub fn verify_pcr_monotonicity(&self) -> Result<(), BootVerifyError> {
321 for i in 0..MAX_PCRS {
323 if self.pcr_state.extend_count[i] > 0 {
324 let all_zero = self.pcr_state.values[i].iter().all(|&b| b == 0);
325 if all_zero {
326 return Err(BootVerifyError::PcrReset);
327 }
328 }
329 }
330 Ok(())
331 }
332
333 pub fn verify_measurement_completeness(&self) -> Result<(), BootVerifyError> {
335 for (i, measured) in self.stages_measured.iter().enumerate() {
336 if !measured {
337 let _stage = BootStage::from_index(i);
338 return Err(BootVerifyError::MissingMeasurement);
339 }
340 }
341 Ok(())
342 }
343
344 #[cfg(feature = "alloc")]
346 pub fn verify_hash_chain(&self) -> Result<(), BootVerifyError> {
347 let mut replay_pcrs = PcrState::new();
349
350 for entry in self.log.entries() {
351 replay_pcrs
352 .extend(entry.pcr_index, &entry.digest)
353 .map_err(|_| BootVerifyError::HashChainBroken)?;
354 }
355
356 for i in 0..MAX_PCRS {
358 if replay_pcrs.values[i] != self.pcr_state.values[i] {
359 return Err(BootVerifyError::HashChainBroken);
360 }
361 }
362
363 Ok(())
364 }
365
366 #[cfg(feature = "alloc")]
368 pub fn verify_boot_policy(&self) -> Result<PolicyDecision, BootVerifyError> {
369 for (pcr_index, expected) in &self.expected_pcrs {
370 if let Some(actual) = self.pcr_state.get(*pcr_index) {
371 if actual != expected {
372 return Ok(PolicyDecision::Deny);
373 }
374 } else {
375 return Err(BootVerifyError::InvalidPcrIndex);
376 }
377 }
378 Ok(PolicyDecision::Allow)
379 }
380
381 #[cfg(feature = "alloc")]
383 pub fn verify_log_ordering(&self) -> Result<(), BootVerifyError> {
384 let entries = self.log.entries();
385 for window in entries.windows(2) {
386 if window[0].sequence >= window[1].sequence {
387 return Err(BootVerifyError::LogOutOfOrder);
388 }
389 }
390 Ok(())
391 }
392
393 #[cfg(feature = "alloc")]
395 pub fn verify_measurement_count(&self) -> Result<(), BootVerifyError> {
396 let mut log_counts = [0u32; MAX_PCRS];
398 for entry in self.log.entries() {
399 if entry.pcr_index < MAX_PCRS {
400 log_counts[entry.pcr_index] = log_counts[entry.pcr_index].saturating_add(1);
401 }
402 }
403
404 for (i, &count) in log_counts.iter().enumerate().take(MAX_PCRS) {
406 if count != self.pcr_state.extend_count[i] {
407 return Err(BootVerifyError::CountMismatch);
408 }
409 }
410
411 Ok(())
412 }
413
414 pub fn status(&self) -> BootStatus {
416 self.status
417 }
418
419 pub fn set_status(&mut self, new_status: BootStatus) -> Result<(), BootVerifyError> {
421 let valid = matches!(
423 (self.status, new_status),
424 (BootStatus::NotStarted, BootStatus::Measuring)
425 | (BootStatus::Measuring, BootStatus::Verifying)
426 | (BootStatus::Verifying, BootStatus::Approved)
427 | (BootStatus::Verifying, BootStatus::Rejected)
428 );
429
430 if valid {
431 self.status = new_status;
432 Ok(())
433 } else {
434 Err(BootVerifyError::PolicyViolation)
435 }
436 }
437}
438
439fn simple_sha256_model(input: &[u8]) -> [u8; DIGEST_LEN] {
445 let mut output = [0u8; DIGEST_LEN];
446
447 let mut state: u64 = 0x6a09_e667_bb67_ae85;
449 for (i, &byte) in input.iter().enumerate() {
450 state = state
451 .wrapping_mul(0x0100_0000_01b3)
452 .wrapping_add(byte as u64);
453 let idx = i % DIGEST_LEN;
454 output[idx] ^= (state & 0xFF) as u8;
455 output[(idx + 7) % DIGEST_LEN] ^= ((state >> 8) & 0xFF) as u8;
456 output[(idx + 13) % DIGEST_LEN] ^= ((state >> 16) & 0xFF) as u8;
457 output[(idx + 21) % DIGEST_LEN] ^= ((state >> 24) & 0xFF) as u8;
458 }
459
460 output
461}
462
/// Kani model-checking harnesses for the PCR model and the boot-chain verifier.
///
/// The module only exists under `cfg(kani)`. Harnesses that use the
/// measurement log or `BootChainVerifier::measure` are also gated on the
/// `alloc` feature, because those APIs are only compiled with it. Without
/// that gate a non-`alloc` Kani build fails on missing methods.
#[cfg(kani)]
mod kani_proofs {
    use super::*;

    /// Extending PCR 0 with a non-zero digest changes its value.
    ///
    /// NOTE(review): this depends on the toy hash never mapping
    /// `0^32 || digest` to all zeros; confirm the harness actually passes.
    #[kani::proof]
    fn proof_pcr_extend_monotonic() {
        let mut pcr = PcrState::new();
        let digest: [u8; DIGEST_LEN] = kani::any();
        kani::assume(digest.iter().any(|&b| b != 0));
        let before = pcr.values[0];
        let _ = pcr.extend(0, &digest);
        let after = pcr.values[0];

        assert!(before != after, "PCR value must change after extend");
    }

    /// Two fresh PCR banks extended with the same digest return the same value.
    #[kani::proof]
    fn proof_pcr_extend_deterministic() {
        let digest: [u8; DIGEST_LEN] = kani::any();

        let mut pcr1 = PcrState::new();
        let mut pcr2 = PcrState::new();

        let r1 = pcr1.extend(0, &digest);
        let r2 = pcr2.extend(0, &digest);

        assert_eq!(r1, r2, "Same input must produce same output");
    }

    /// Consecutive log entries receive strictly increasing sequence numbers.
    #[cfg(feature = "alloc")]
    #[kani::proof]
    fn proof_measurement_log_ordered() {
        let mut log = MeasurementLog::new();
        let d1: [u8; DIGEST_LEN] = kani::any();
        let d2: [u8; DIGEST_LEN] = kani::any();

        log.add(0, d1, MeasuredComponent::FirmwareCode);
        log.add(0, d2, MeasuredComponent::BootloaderBinary);

        let entries = log.entries();
        assert!(entries[0].sequence < entries[1].sequence);
    }

    /// The status machine accepts the forward path and rejects skipping
    /// `Verifying`.
    #[kani::proof]
    fn proof_boot_status_transitions() {
        let mut verifier = BootChainVerifier::new();
        assert_eq!(verifier.status(), BootStatus::NotStarted);

        assert!(verifier.set_status(BootStatus::Measuring).is_ok());
        // Measuring -> Approved skips verification and must be refused.
        assert!(verifier.set_status(BootStatus::Approved).is_err());
        assert!(verifier.set_status(BootStatus::Verifying).is_ok());
        assert!(verifier.set_status(BootStatus::Approved).is_ok());
    }

    /// `verify_boot_policy` returns `Allow` exactly when the expected value
    /// equals the PCR contents and `Deny` otherwise, for any expected value.
    #[cfg(feature = "alloc")]
    #[kani::proof]
    fn proof_policy_decision_complete() {
        let mut verifier = BootChainVerifier::new();
        let expected: [u8; DIGEST_LEN] = kani::any();
        verifier.set_expected_pcr(0, expected);

        let decision = verifier.verify_boot_policy();
        // PCR 0 was never extended, so it still holds all zeros.
        if expected == [0u8; DIGEST_LEN] {
            assert_eq!(decision, Ok(PolicyDecision::Allow));
        } else {
            assert_eq!(decision, Ok(PolicyDecision::Deny));
        }
    }

    /// For a verifier fed only through `measure`, replaying the log
    /// reproduces the live PCRs, the log stays ordered and the counts agree.
    #[cfg(feature = "alloc")]
    #[kani::proof]
    fn proof_hash_chain_integrity() {
        let mut verifier = BootChainVerifier::new();
        let d1: [u8; DIGEST_LEN] = kani::any();
        let d2: [u8; DIGEST_LEN] = kani::any();

        let r1 = verifier.measure(BootStage::Firmware, 0, d1, MeasuredComponent::FirmwareCode);
        let r2 = verifier.measure(
            BootStage::Bootloader,
            0,
            d2,
            MeasuredComponent::BootloaderBinary,
        );
        assert!(r1.is_ok() && r2.is_ok());

        assert!(verifier.verify_hash_chain().is_ok());
        assert!(verifier.verify_log_ordering().is_ok());
        assert!(verifier.verify_measurement_count().is_ok());
    }

    /// The extend count never decreases across successive extends.
    #[kani::proof]
    fn proof_pcr_no_reset() {
        let mut pcr = PcrState::new();
        let d1: [u8; DIGEST_LEN] = kani::any();

        let _ = pcr.extend(0, &d1);
        let count_after_first = pcr.extend_count[0];

        let d2: [u8; DIGEST_LEN] = kani::any();
        let _ = pcr.extend(0, &d2);
        let count_after_second = pcr.extend_count[0];

        assert!(
            count_after_second >= count_after_first,
            "Extend count must be monotonically increasing"
        );
    }

    /// After one measurement, log counts and PCR extend counts agree.
    #[cfg(feature = "alloc")]
    #[kani::proof]
    fn proof_measurement_count_matches() {
        let mut verifier = BootChainVerifier::new();
        let d: [u8; DIGEST_LEN] = kani::any();

        let _ = verifier.measure(BootStage::Firmware, 0, d, MeasuredComponent::FirmwareCode);
        assert!(verifier.verify_measurement_count().is_ok());
    }
}
600
/// Unit tests for the PCR model, the measurement log and the boot-chain verifier.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a verifier in which every boot stage has been measured once.
    /// Stage `i` goes into PCR `i` with a digest whose first byte is `i + 1`.
    #[cfg(feature = "alloc")]
    fn measured_verifier() -> BootChainVerifier {
        let stages = [
            (BootStage::Firmware, 0, MeasuredComponent::FirmwareCode),
            (BootStage::Bootloader, 1, MeasuredComponent::BootloaderBinary),
            (BootStage::Kernel, 2, MeasuredComponent::KernelImage),
            (BootStage::InitSystem, 3, MeasuredComponent::InitBinary),
            (BootStage::DriverFramework, 4, MeasuredComponent::DriverBinary),
            (BootStage::UserSpace, 5, MeasuredComponent::UserComponent),
        ];

        let mut v = BootChainVerifier::new();
        for (i, &(stage, pcr, component)) in stages.iter().enumerate() {
            let mut digest = [0u8; DIGEST_LEN];
            digest[0] = (i + 1) as u8;
            v.measure(stage, pcr, digest, component).unwrap();
        }
        v
    }

    #[test]
    fn test_pcr_initial_state() {
        let pcr = PcrState::new();
        let zero = [0u8; DIGEST_LEN];
        assert_eq!(pcr.get(0), Some(&zero));
        assert_eq!(pcr.get_extend_count(0), Some(0));
        assert!(!pcr.is_extended(0));
    }

    #[test]
    fn test_pcr_extend() {
        let mut pcr = PcrState::new();
        let digest = [0x42u8; DIGEST_LEN];
        let result = pcr.extend(0, &digest);
        assert!(result.is_ok());
        assert!(pcr.is_extended(0));
        assert_eq!(pcr.get_extend_count(0), Some(1));
    }

    #[test]
    fn test_pcr_extend_chains_previous_value() {
        let mut pcr = PcrState::new();
        let first = pcr.extend(5, &[0x01; DIGEST_LEN]).unwrap();
        let second_digest = [0x02u8; DIGEST_LEN];
        let second = pcr.extend(5, &second_digest).unwrap();

        // new = H(old || digest)
        let mut preimage = [0u8; 2 * DIGEST_LEN];
        preimage[..DIGEST_LEN].copy_from_slice(&first);
        preimage[DIGEST_LEN..].copy_from_slice(&second_digest);
        assert_eq!(second, simple_sha256_model(&preimage));
        assert_eq!(pcr.get(5), Some(&second));
        assert_eq!(pcr.get_extend_count(5), Some(2));
    }

    #[test]
    fn test_pcr_invalid_index() {
        let mut pcr = PcrState::new();
        let digest = [0x42u8; DIGEST_LEN];
        let result = pcr.extend(MAX_PCRS, &digest);
        assert_eq!(result, Err(BootVerifyError::InvalidPcrIndex));
        assert_eq!(pcr.get(MAX_PCRS), None);
        assert_eq!(pcr.get_extend_count(MAX_PCRS), None);
        assert!(!pcr.is_extended(MAX_PCRS));
    }

    #[test]
    fn test_pcr_extend_deterministic() {
        let mut pcr1 = PcrState::new();
        let mut pcr2 = PcrState::new();
        let digest = [0xAB; DIGEST_LEN];
        let r1 = pcr1.extend(0, &digest).unwrap();
        let r2 = pcr2.extend(0, &digest).unwrap();
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_boot_stage_index_round_trip() {
        for i in 0..BootStage::COUNT {
            let stage = BootStage::from_index(i).expect("index below COUNT");
            assert_eq!(stage as usize, i);
        }
        assert_eq!(BootStage::from_index(BootStage::COUNT), None);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_measurement_log() {
        let mut log = MeasurementLog::new();
        assert!(log.is_empty());

        log.add(0, [0x11; DIGEST_LEN], MeasuredComponent::FirmwareCode);
        log.add(1, [0x22; DIGEST_LEN], MeasuredComponent::BootloaderBinary);

        assert_eq!(log.len(), 2);
        assert_eq!(log.entries()[0].sequence, 0);
        assert_eq!(log.entries()[1].sequence, 1);
    }

    #[test]
    fn test_boot_status_transitions() {
        let mut v = BootChainVerifier::new();
        assert_eq!(v.status(), BootStatus::NotStarted);

        assert!(v.set_status(BootStatus::Measuring).is_ok());
        assert!(v.set_status(BootStatus::Verifying).is_ok());
        assert!(v.set_status(BootStatus::Approved).is_ok());

        // Approved is terminal.
        assert!(v.set_status(BootStatus::Measuring).is_err());
    }

    #[test]
    fn test_boot_status_rejected_path() {
        let mut v = BootChainVerifier::new();
        v.set_status(BootStatus::Measuring).unwrap();
        v.set_status(BootStatus::Verifying).unwrap();
        assert!(v.set_status(BootStatus::Rejected).is_ok());
        assert_eq!(v.status(), BootStatus::Rejected);
        // Rejected is terminal too.
        assert_eq!(
            v.set_status(BootStatus::Approved),
            Err(BootVerifyError::PolicyViolation)
        );
    }

    #[test]
    fn test_invalid_status_transition() {
        let mut v = BootChainVerifier::new();
        // Verification outcomes cannot be reached directly from NotStarted.
        assert!(v.set_status(BootStatus::Approved).is_err());
        assert!(v.set_status(BootStatus::Rejected).is_err());
        assert_eq!(v.status(), BootStatus::NotStarted);
    }

    #[test]
    fn test_pcr_reset_detected() {
        let mut v = BootChainVerifier::new();
        v.pcr_state.extend(0, &[0x42; DIGEST_LEN]).unwrap();
        // Simulate a register being wiped after it was extended.
        v.pcr_state.values[0] = [0; DIGEST_LEN];
        assert_eq!(v.verify_pcr_monotonicity(), Err(BootVerifyError::PcrReset));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_full_boot_chain_verification() {
        let v = measured_verifier();

        assert!(v.verify_pcr_monotonicity().is_ok());
        assert!(v.verify_measurement_completeness().is_ok());
        assert!(v.verify_hash_chain().is_ok());
        assert!(v.verify_log_ordering().is_ok());
        assert!(v.verify_measurement_count().is_ok());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_incomplete_boot_chain() {
        let mut v = BootChainVerifier::new();
        let digest = [0x42u8; DIGEST_LEN];
        v.measure(
            BootStage::Firmware,
            0,
            digest,
            MeasuredComponent::FirmwareCode,
        )
        .unwrap();

        assert_eq!(
            v.verify_measurement_completeness(),
            Err(BootVerifyError::MissingMeasurement)
        );
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_measure_invalid_pcr_index_records_nothing() {
        let mut v = BootChainVerifier::new();
        assert_eq!(
            v.measure(
                BootStage::Firmware,
                MAX_PCRS,
                [0x01; DIGEST_LEN],
                MeasuredComponent::FirmwareCode,
            ),
            Err(BootVerifyError::InvalidPcrIndex)
        );
        assert!(v.log.is_empty());
        assert!(!v.stages_measured[BootStage::Firmware as usize]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_boot_policy_allow_then_deny() {
        let mut v = BootChainVerifier::new();
        let digest = [0x42u8; DIGEST_LEN];
        v.measure(BootStage::Firmware, 0, digest, MeasuredComponent::FirmwareCode)
            .unwrap();

        let expected = PcrState::new().extend(0, &digest).unwrap();
        v.set_expected_pcr(0, expected);
        assert_eq!(v.verify_boot_policy(), Ok(PolicyDecision::Allow));

        // PCR 1 was never extended (all zeros), so this expectation fails.
        v.set_expected_pcr(1, [0xFF; DIGEST_LEN]);
        assert_eq!(v.verify_boot_policy(), Ok(PolicyDecision::Deny));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_boot_policy_invalid_index() {
        let mut v = BootChainVerifier::new();
        v.set_expected_pcr(MAX_PCRS, [0; DIGEST_LEN]);
        assert_eq!(
            v.verify_boot_policy(),
            Err(BootVerifyError::InvalidPcrIndex)
        );
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hash_chain_detects_tampered_pcr() {
        let mut v = measured_verifier();
        v.pcr_state.values[2][0] ^= 0x01;
        assert_eq!(v.verify_hash_chain(), Err(BootVerifyError::HashChainBroken));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_measurement_count_detects_mismatch() {
        let mut v = measured_verifier();
        v.pcr_state.extend_count[0] += 1;
        assert_eq!(
            v.verify_measurement_count(),
            Err(BootVerifyError::CountMismatch)
        );
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_log_ordering_detects_reordering() {
        let mut v = measured_verifier();
        v.log.entries.swap(0, 1);
        assert_eq!(v.verify_log_ordering(), Err(BootVerifyError::LogOutOfOrder));
    }
}