1#![allow(dead_code)]
9
10#[cfg(feature = "alloc")]
11extern crate alloc;
12
13#[cfg(feature = "alloc")]
14use alloc::vec::Vec;
15
16use super::VmError;
17
/// Maximum number of devices tracked in a single IOMMU group.
const MAX_DEVICES_PER_GROUP: usize = 32;

/// Maximum number of IOMMU groups attachable to one VFIO container.
const MAX_GROUPS_PER_CONTAINER: usize = 64;

/// A PCI function exposes at most six base address registers (BAR0-BAR5).
const MAX_BAR_REGIONS: usize = 6;

/// Maximum number of concurrent DMA mappings per container.
const MAX_DMA_MAPPINGS: usize = 256;

/// Upper bound on MSI-X vectors accepted by `enable_msix`.
const MAX_MSIX_VECTORS: usize = 2048;

/// NOTE(review): matches the number of `VfioIrqType` variants; currently
/// unused in this chunk (`dead_code` is allowed above) — confirm intent.
const MAX_IRQS: usize = 4;
39
/// A PCI bus/device/function (BDF) address identifying one PCI function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct PciAddress {
    /// Bus number (full 8-bit range).
    pub bus: u8,
    /// Device (slot) number; `new` masks this to 5 bits (0-31).
    pub device: u8,
    /// Function number; `new` masks this to 3 bits (0-7).
    pub function: u8,
}
54
55impl PciAddress {
56 pub fn new(bus: u8, device: u8, function: u8) -> Self {
58 Self {
59 bus,
60 device: device & 0x1F,
61 function: function & 0x07,
62 }
63 }
64
65 pub fn to_bdf(&self) -> u16 {
67 ((self.bus as u16) << 8) | ((self.device as u16) << 3) | (self.function as u16)
68 }
69
70 pub fn from_bdf(bdf: u16) -> Self {
72 Self {
73 bus: (bdf >> 8) as u8,
74 device: ((bdf >> 3) & 0x1F) as u8,
75 function: (bdf & 0x07) as u8,
76 }
77 }
78}
79
80impl core::fmt::Display for PciAddress {
81 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
82 write!(f, "{:02x}:{:02x}.{}", self.bus, self.device, self.function)
83 }
84}
85
/// Bitflag set describing a BAR region's access properties.
///
/// See the associated constants on the `impl` (`IO`, `MEMORY`,
/// `PREFETCHABLE`, `BIT64`) for the individual bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct BarFlags {
    // Raw bit set; combined via `union`.
    bits: u32,
}
95
96impl BarFlags {
97 pub const IO: Self = Self { bits: 1 };
99 pub const MEMORY: Self = Self { bits: 2 };
101 pub const PREFETCHABLE: Self = Self { bits: 4 };
103 pub const BIT64: Self = Self { bits: 8 };
105
106 pub(crate) fn is_io(self) -> bool {
108 self.bits & 1 != 0
109 }
110
111 pub(crate) fn is_memory(self) -> bool {
113 self.bits & 2 != 0
114 }
115
116 pub(crate) fn is_prefetchable(self) -> bool {
118 self.bits & 4 != 0
119 }
120
121 pub(crate) fn is_64bit(self) -> bool {
123 self.bits & 8 != 0
124 }
125
126 pub(crate) fn union(self, other: Self) -> Self {
128 Self {
129 bits: self.bits | other.bits,
130 }
131 }
132}
133
/// One PCI base address register region and its guest-mapping state.
#[derive(Debug, Clone, Copy)]
pub struct BarRegion {
    /// BAR index (conventionally 0-5).
    pub index: u8,
    /// Host-side base address of the region.
    pub base_addr: u64,
    /// Region size in bytes.
    pub size: u64,
    /// Access properties (I/O vs memory, prefetchable, 64-bit).
    pub flags: BarFlags,
    /// Whether the region is currently mapped into a guest.
    pub mapped: bool,
    /// Guest address of the mapping; only meaningful while `mapped` is true.
    pub guest_addr: u64,
}
150
151impl BarRegion {
152 pub fn new(index: u8, base_addr: u64, size: u64, flags: BarFlags) -> Self {
154 Self {
155 index,
156 base_addr,
157 size,
158 flags,
159 mapped: false,
160 guest_addr: 0,
161 }
162 }
163
164 pub(crate) fn map_to_guest(&mut self, guest_addr: u64) {
166 self.guest_addr = guest_addr;
167 self.mapped = true;
168 }
169
170 pub(crate) fn unmap(&mut self) {
172 self.mapped = false;
173 self.guest_addr = 0;
174 }
175}
176
/// Bitflag set describing permitted DMA directions for a mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct DmaFlags {
    // Raw bit set: bit 0 = read, bit 1 = write.
    bits: u32,
}
186
187impl DmaFlags {
188 pub const READ: Self = Self { bits: 1 };
190 pub const WRITE: Self = Self { bits: 2 };
192 pub const READ_WRITE: Self = Self { bits: 3 };
194
195 pub(crate) fn is_readable(self) -> bool {
197 self.bits & 1 != 0
198 }
199
200 pub(crate) fn is_writable(self) -> bool {
202 self.bits & 2 != 0
203 }
204}
205
/// An IOVA-to-physical DMA translation entry.
#[derive(Debug, Clone, Copy)]
pub struct DmaMapping {
    /// Guest I/O virtual address where the mapping starts.
    pub iova: u64,
    /// Length of the mapped range in bytes.
    pub size: u64,
    /// Host physical address backing the start of the range.
    pub paddr: u64,
    /// Permitted DMA directions.
    pub flags: DmaFlags,
}
218
219impl DmaMapping {
220 pub fn new(iova: u64, size: u64, paddr: u64, flags: DmaFlags) -> Self {
222 Self {
223 iova,
224 size,
225 paddr,
226 flags,
227 }
228 }
229
230 pub(crate) fn contains(&self, iova: u64) -> bool {
232 iova >= self.iova && iova < self.iova + self.size
233 }
234
235 pub(crate) fn translate(&self, iova: u64) -> Option<u64> {
237 if self.contains(iova) {
238 Some(self.paddr + (iova - self.iova))
239 } else {
240 None
241 }
242 }
243}
244
/// Interrupt delivery mechanisms a VFIO device may use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VfioIrqType {
    /// Legacy level-triggered INTx pin.
    Intx = 0,
    /// Message Signaled Interrupts.
    Msi = 1,
    /// Extended Message Signaled Interrupts (per-vector tables).
    MsiX = 2,
    /// Device error notification interrupt.
    Err = 3,
}
261
/// Configuration state for one interrupt type on a device.
#[derive(Debug, Clone, Copy)]
pub struct VfioIrqInfo {
    /// Which interrupt mechanism this entry describes.
    pub irq_type: VfioIrqType,
    /// Number of vectors configured for this mechanism.
    pub count: u32,
    /// Whether delivery is currently enabled.
    pub enabled: bool,
    /// Implementation-defined flags; not interpreted in this chunk.
    pub flags: u32,
}
274
275impl VfioIrqInfo {
276 pub fn new(irq_type: VfioIrqType, count: u32) -> Self {
278 Self {
279 irq_type,
280 count,
281 enabled: false,
282 flags: 0,
283 }
284 }
285}
286
/// An IOMMU isolation group: the smallest set of devices that must be
/// assigned together.
#[cfg(feature = "alloc")]
pub struct IommuGroup {
    /// Kernel-assigned group identifier.
    pub group_id: u32,
    /// PCI addresses of the member devices.
    pub devices: Vec<PciAddress>,
    /// Whether the group is attached to a container.
    pub attached: bool,
    /// Id of the owning container; `None` while detached.
    pub container_id: Option<u32>,
}
303
#[cfg(feature = "alloc")]
impl IommuGroup {
    /// Creates an empty, detached group.
    pub fn new(group_id: u32) -> Self {
        Self {
            group_id,
            devices: Vec::new(),
            attached: false,
            container_id: None,
        }
    }

    /// Registers a device in this group.
    ///
    /// # Errors
    /// Fails when the group already holds `MAX_DEVICES_PER_GROUP` devices
    /// or when `addr` is already a member.
    pub(crate) fn add_device(&mut self, addr: PciAddress) -> Result<(), VmError> {
        let full = self.devices.len() >= MAX_DEVICES_PER_GROUP;
        if full || self.devices.contains(&addr) {
            return Err(VmError::DeviceError);
        }
        self.devices.push(addr);
        Ok(())
    }

    /// Removes a device, returning whether it was present. The order of the
    /// remaining members is not preserved (`swap_remove`).
    pub(crate) fn remove_device(&mut self, addr: &PciAddress) -> bool {
        match self.devices.iter().position(|d| d == addr) {
            Some(idx) => {
                self.devices.swap_remove(idx);
                true
            }
            None => false,
        }
    }

    /// Returns whether `addr` is a member of this group.
    pub(crate) fn contains_device(&self, addr: &PciAddress) -> bool {
        self.devices.iter().any(|d| d == addr)
    }

    /// Marks the group as attached to the given container.
    pub(crate) fn attach(&mut self, container_id: u32) {
        self.container_id = Some(container_id);
        self.attached = true;
    }

    /// Marks the group as detached from any container.
    pub(crate) fn detach(&mut self) {
        self.container_id = None;
        self.attached = false;
    }
}
355
/// A VFIO container: an IOMMU address space shared by its attached groups,
/// together with the DMA mappings established in it.
#[cfg(feature = "alloc")]
pub struct VfioContainer {
    /// IOMMU backend type selected for this container.
    pub iommu_type: u32,
    /// Groups attached to this container.
    pub groups: Vec<IommuGroup>,
    /// Active, non-overlapping DMA mappings.
    pub dma_mappings: Vec<DmaMapping>,
    /// Identifier of this container.
    pub container_id: u32,
}
372
#[cfg(feature = "alloc")]
impl VfioContainer {
    /// Creates an empty container with the given IOMMU backend type.
    pub fn new(container_id: u32, iommu_type: u32) -> Self {
        Self {
            iommu_type,
            groups: Vec::new(),
            dma_mappings: Vec::new(),
            container_id,
        }
    }

    /// Attaches a group to this container, marking it attached.
    ///
    /// # Errors
    /// Fails when the container already holds `MAX_GROUPS_PER_CONTAINER`
    /// groups, or when a group with the same id is already attached
    /// (IOMMU group ids are unique, so a duplicate indicates a caller bug).
    pub(crate) fn add_group(&mut self, mut group: IommuGroup) -> Result<(), VmError> {
        if self.groups.len() >= MAX_GROUPS_PER_CONTAINER {
            return Err(VmError::DeviceError);
        }
        if self.groups.iter().any(|g| g.group_id == group.group_id) {
            return Err(VmError::DeviceError);
        }
        group.attach(self.container_id);
        self.groups.push(group);
        Ok(())
    }

    /// Establishes a DMA mapping after validating it.
    ///
    /// # Errors
    /// Fails when the mapping table is full, the mapping is empty
    /// (size 0 can never be translated and would bypass the overlap check),
    /// the IOVA range would wrap past `u64::MAX`, or the range overlaps an
    /// existing mapping.
    pub(crate) fn dma_map(&mut self, mapping: DmaMapping) -> Result<(), VmError> {
        if self.dma_mappings.len() >= MAX_DMA_MAPPINGS {
            return Err(VmError::DeviceError);
        }
        if mapping.size == 0 {
            return Err(VmError::DeviceError);
        }
        // Reject wrapping ranges up front; this also keeps the overlap
        // arithmetic below free of overflow.
        let new_end = mapping
            .iova
            .checked_add(mapping.size)
            .ok_or(VmError::DeviceError)?;
        for existing in &self.dma_mappings {
            // saturating_add is defensive: entries inserted through this
            // method have already been validated as non-wrapping.
            let existing_end = existing.iova.saturating_add(existing.size);
            if mapping.iova < existing_end && new_end > existing.iova {
                return Err(VmError::DeviceError);
            }
        }
        self.dma_mappings.push(mapping);
        Ok(())
    }

    /// Removes the mapping that starts exactly at `iova`, returning its size.
    ///
    /// # Errors
    /// Fails when no mapping starts at `iova`.
    pub(crate) fn dma_unmap(&mut self, iova: u64) -> Result<u64, VmError> {
        let pos = self
            .dma_mappings
            .iter()
            .position(|m| m.iova == iova)
            .ok_or(VmError::DeviceError)?;
        let size = self.dma_mappings[pos].size;
        self.dma_mappings.swap_remove(pos);
        Ok(size)
    }

    /// Translates a guest IOVA through the mapping table, or `None` when no
    /// mapping covers it. Mappings are non-overlapping, so at most one can
    /// match.
    pub(crate) fn translate_iova(&self, iova: u64) -> Option<u64> {
        self.dma_mappings.iter().find_map(|m| m.translate(iova))
    }

    /// Number of groups attached to this container.
    pub(crate) fn group_count(&self) -> usize {
        self.groups.len()
    }

    /// Number of active DMA mappings.
    pub(crate) fn dma_mapping_count(&self) -> usize {
        self.dma_mappings.len()
    }
}
443
/// A PCI device handle opened through VFIO for passthrough to a VM.
#[cfg(feature = "alloc")]
pub struct VfioDevice {
    /// IOMMU group this device belongs to.
    pub group_id: u32,
    /// Bus/device/function address of the physical device.
    pub pci_address: PciAddress,
    /// BAR regions registered for this device.
    pub bar_regions: Vec<BarRegion>,
    /// Per-interrupt-type configuration entries.
    pub irqs: Vec<VfioIrqInfo>,
    /// Whether the handle is open; set on `open`.
    pub opened: bool,
    /// PCI vendor id of the device.
    pub vendor_id: u16,
    /// PCI device id of the device.
    pub device_id: u16,
    /// Id of the VM the device is assigned to, if any.
    pub assigned_vm: Option<u32>,
}
468
#[cfg(feature = "alloc")]
impl VfioDevice {
    /// Opens a device handle for the given PCI function.
    ///
    /// The device starts with no BARs or IRQs registered and unassigned.
    /// Currently infallible; the `Result` return is kept for callers that
    /// expect open to be able to fail.
    pub fn open(
        group_id: u32,
        pci_address: PciAddress,
        vendor_id: u16,
        device_id: u16,
    ) -> Result<Self, VmError> {
        Ok(Self {
            group_id,
            pci_address,
            bar_regions: Vec::new(),
            irqs: Vec::new(),
            opened: true,
            vendor_id,
            device_id,
            assigned_vm: None,
        })
    }

    /// Registers a BAR region on this device.
    ///
    /// # Errors
    /// Fails when the BAR index is out of range (a PCI function has at most
    /// `MAX_BAR_REGIONS` BARs), when a region with the same index is already
    /// registered (duplicates would make `bar`/`map_bar` lookups ambiguous),
    /// or when the region table is full.
    pub(crate) fn add_bar(&mut self, region: BarRegion) -> Result<(), VmError> {
        if region.index as usize >= MAX_BAR_REGIONS {
            return Err(VmError::DeviceError);
        }
        if self.bar_regions.iter().any(|b| b.index == region.index) {
            return Err(VmError::DeviceError);
        }
        if self.bar_regions.len() >= MAX_BAR_REGIONS {
            return Err(VmError::DeviceError);
        }
        self.bar_regions.push(region);
        Ok(())
    }

    /// Maps the BAR with index `bar_index` at `guest_addr`.
    ///
    /// # Errors
    /// Fails when no BAR with that index is registered.
    pub(crate) fn map_bar(&mut self, bar_index: u8, guest_addr: u64) -> Result<(), VmError> {
        match self.bar_regions.iter_mut().find(|b| b.index == bar_index) {
            Some(bar) => {
                bar.map_to_guest(guest_addr);
                Ok(())
            }
            None => Err(VmError::DeviceError),
        }
    }

    /// Unmaps the BAR with index `bar_index` from the guest.
    ///
    /// # Errors
    /// Fails when no BAR with that index is registered.
    pub(crate) fn unmap_bar(&mut self, bar_index: u8) -> Result<(), VmError> {
        match self.bar_regions.iter_mut().find(|b| b.index == bar_index) {
            Some(bar) => {
                bar.unmap();
                Ok(())
            }
            None => Err(VmError::DeviceError),
        }
    }

    /// Enables MSI-X with `num_vectors` vectors, replacing any previous
    /// MSI-X configuration.
    ///
    /// # Errors
    /// Fails for zero vectors (MSI-X requires at least one) or more than
    /// `MAX_MSIX_VECTORS`.
    pub(crate) fn enable_msix(&mut self, num_vectors: u32) -> Result<(), VmError> {
        if num_vectors == 0 || num_vectors > MAX_MSIX_VECTORS as u32 {
            return Err(VmError::DeviceError);
        }
        // Drop any stale MSI-X entry so exactly one is tracked at a time.
        self.irqs.retain(|i| i.irq_type != VfioIrqType::MsiX);
        let mut irq = VfioIrqInfo::new(VfioIrqType::MsiX, num_vectors);
        irq.enabled = true;
        self.irqs.push(irq);
        Ok(())
    }

    /// Disables MSI-X delivery; the configuration entry is kept so the
    /// vector count remains queryable.
    pub(crate) fn disable_msix(&mut self) {
        if let Some(irq) = self
            .irqs
            .iter_mut()
            .find(|i| i.irq_type == VfioIrqType::MsiX)
        {
            irq.enabled = false;
        }
    }

    /// Resets device state: unmaps all BARs and disables all IRQs.
    /// BAR and IRQ registrations themselves are retained.
    pub(crate) fn reset(&mut self) -> Result<(), VmError> {
        for bar in &mut self.bar_regions {
            bar.unmap();
        }
        for irq in &mut self.irqs {
            irq.enabled = false;
        }
        Ok(())
    }

    /// Assigns the device to a VM.
    ///
    /// # Errors
    /// Fails when the device is already assigned; callers must `unassign`
    /// first.
    pub(crate) fn assign_to_vm(&mut self, vm_id: u32) -> Result<(), VmError> {
        if self.assigned_vm.is_some() {
            return Err(VmError::DeviceError);
        }
        self.assigned_vm = Some(vm_id);
        Ok(())
    }

    /// Releases the device from its current VM, if any.
    pub(crate) fn unassign(&mut self) {
        self.assigned_vm = None;
    }

    /// Whether the device is currently assigned to a VM.
    pub(crate) fn is_assigned(&self) -> bool {
        self.assigned_vm.is_some()
    }

    /// Id of the VM the device is assigned to, if any.
    pub(crate) fn assigned_vm_id(&self) -> Option<u32> {
        self.assigned_vm
    }

    /// Looks up a registered BAR by index.
    pub(crate) fn bar(&self, index: u8) -> Option<&BarRegion> {
        self.bar_regions.iter().find(|b| b.index == index)
    }

    /// Whether MSI-X is configured and currently enabled.
    pub(crate) fn msix_enabled(&self) -> bool {
        self.irqs
            .iter()
            .any(|i| i.irq_type == VfioIrqType::MsiX && i.enabled)
    }

    /// Configured MSI-X vector count, or 0 when MSI-X was never enabled.
    pub(crate) fn msix_vector_count(&self) -> u32 {
        self.irqs
            .iter()
            .find(|i| i.irq_type == VfioIrqType::MsiX)
            .map_or(0, |i| i.count)
    }
}
604
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pci_address_new() {
        let address = PciAddress::new(0, 31, 7);
        assert_eq!(address.bus, 0);
        assert_eq!(address.device, 31);
        assert_eq!(address.function, 7);
    }

    #[test]
    fn test_pci_address_bdf_roundtrip() {
        let original = PciAddress::new(2, 3, 1);
        let encoded = original.to_bdf();
        assert_eq!(PciAddress::from_bdf(encoded), original);
    }

    #[test]
    fn test_pci_address_mask() {
        // Out-of-range device/function bits must be truncated.
        let address = PciAddress::new(0, 0xFF, 0xFF);
        assert_eq!(address.device, 0x1F);
        assert_eq!(address.function, 0x07);
    }

    #[test]
    fn test_bar_region_map() {
        let mut region = BarRegion::new(0, 0xFE00_0000, 0x1000, BarFlags::MEMORY);
        assert!(!region.mapped);
        region.map_to_guest(0xC000_0000);
        assert!(region.mapped);
        assert_eq!(region.guest_addr, 0xC000_0000);
    }

    #[test]
    fn test_bar_region_unmap() {
        let mut region = BarRegion::new(0, 0xFE00_0000, 0x1000, BarFlags::MEMORY);
        region.map_to_guest(0xC000_0000);
        region.unmap();
        assert!(!region.mapped);
    }

    #[test]
    fn test_bar_flags() {
        let combined = BarFlags::MEMORY.union(BarFlags::PREFETCHABLE);
        assert!(combined.is_memory());
        assert!(combined.is_prefetchable());
        assert!(!combined.is_io());
    }

    #[test]
    fn test_dma_mapping_translate() {
        let mapping = DmaMapping::new(0x1000, 0x2000, 0x8000_0000, DmaFlags::READ_WRITE);
        assert_eq!(mapping.translate(0x1000), Some(0x8000_0000));
        assert_eq!(mapping.translate(0x2000), Some(0x8000_1000));
        // One past the end and one before the start are both outside.
        assert_eq!(mapping.translate(0x3000), None);
        assert_eq!(mapping.translate(0x0FFF), None);
    }

    #[test]
    fn test_dma_mapping_contains() {
        let mapping = DmaMapping::new(0x1000, 0x2000, 0, DmaFlags::READ);
        assert!(mapping.contains(0x1000));
        assert!(mapping.contains(0x2FFF));
        assert!(!mapping.contains(0x3000));
    }

    #[test]
    fn test_iommu_group() {
        let mut group = IommuGroup::new(1);
        let first = PciAddress::new(0, 1, 0);
        let second = PciAddress::new(0, 2, 0);
        group.add_device(first).unwrap();
        group.add_device(second).unwrap();
        assert_eq!(group.devices.len(), 2);
        assert!(group.contains_device(&first));
    }

    #[test]
    fn test_iommu_group_duplicate_device() {
        let mut group = IommuGroup::new(1);
        let address = PciAddress::new(0, 1, 0);
        group.add_device(address).unwrap();
        // Adding the same device twice must be rejected.
        assert!(group.add_device(address).is_err());
    }

    #[test]
    fn test_iommu_group_remove_device() {
        let mut group = IommuGroup::new(1);
        let address = PciAddress::new(0, 1, 0);
        group.add_device(address).unwrap();
        assert!(group.remove_device(&address));
        assert!(!group.contains_device(&address));
    }

    #[test]
    fn test_iommu_group_attach_detach() {
        let mut group = IommuGroup::new(1);
        group.attach(42);
        assert!(group.attached);
        assert_eq!(group.container_id, Some(42));
        group.detach();
        assert!(!group.attached);
        assert_eq!(group.container_id, None);
    }

    #[test]
    fn test_vfio_container() {
        let mut container = VfioContainer::new(1, 1);
        container.add_group(IommuGroup::new(1)).unwrap();
        assert_eq!(container.group_count(), 1);
    }

    #[test]
    fn test_vfio_container_dma() {
        let mut container = VfioContainer::new(1, 1);
        let mapping = DmaMapping::new(0x1000, 0x2000, 0x8000_0000, DmaFlags::READ_WRITE);
        container.dma_map(mapping).unwrap();
        assert_eq!(container.dma_mapping_count(), 1);
        assert_eq!(container.translate_iova(0x1500), Some(0x8000_0500));
    }

    #[test]
    fn test_vfio_container_dma_overlap() {
        let mut container = VfioContainer::new(1, 1);
        let first = DmaMapping::new(0x1000, 0x2000, 0, DmaFlags::READ);
        // Second mapping starts inside the first: [0x2000, 0x3000) vs [0x1000, 0x3000).
        let second = DmaMapping::new(0x2000, 0x1000, 0, DmaFlags::READ);
        container.dma_map(first).unwrap();
        assert!(container.dma_map(second).is_err());
    }

    #[test]
    fn test_vfio_container_dma_unmap() {
        let mut container = VfioContainer::new(1, 1);
        container
            .dma_map(DmaMapping::new(0x1000, 0x2000, 0, DmaFlags::READ))
            .unwrap();
        let unmapped = container.dma_unmap(0x1000).unwrap();
        assert_eq!(unmapped, 0x2000);
        assert_eq!(container.dma_mapping_count(), 0);
    }

    #[test]
    fn test_vfio_device_open() {
        let device = VfioDevice::open(1, PciAddress::new(0, 3, 0), 0x8086, 0x1234).unwrap();
        assert!(device.opened);
        assert_eq!(device.vendor_id, 0x8086);
        assert_eq!(device.device_id, 0x1234);
    }

    #[test]
    fn test_vfio_device_bar() {
        let mut device = VfioDevice::open(1, PciAddress::new(0, 3, 0), 0x8086, 0x1234).unwrap();
        device
            .add_bar(BarRegion::new(0, 0xFE00_0000, 0x10000, BarFlags::MEMORY))
            .unwrap();
        device.map_bar(0, 0xC000_0000).unwrap();
        let bar = device.bar(0).unwrap();
        assert!(bar.mapped);
        assert_eq!(bar.guest_addr, 0xC000_0000);
    }

    #[test]
    fn test_vfio_device_msix() {
        let mut device = VfioDevice::open(1, PciAddress::new(0, 3, 0), 0x8086, 0x1234).unwrap();
        device.enable_msix(16).unwrap();
        assert!(device.msix_enabled());
        assert_eq!(device.msix_vector_count(), 16);
        device.disable_msix();
        assert!(!device.msix_enabled());
    }

    #[test]
    fn test_vfio_device_reset() {
        let mut device = VfioDevice::open(1, PciAddress::new(0, 3, 0), 0x8086, 0x1234).unwrap();
        device
            .add_bar(BarRegion::new(0, 0xFE00_0000, 0x10000, BarFlags::MEMORY))
            .unwrap();
        device.map_bar(0, 0xC000_0000).unwrap();
        device.enable_msix(4).unwrap();
        device.reset().unwrap();
        assert!(!device.bar(0).unwrap().mapped);
        assert!(!device.msix_enabled());
    }

    #[test]
    fn test_vfio_device_assign() {
        let mut device = VfioDevice::open(1, PciAddress::new(0, 3, 0), 0x8086, 0x1234).unwrap();
        assert!(!device.is_assigned());
        device.assign_to_vm(1).unwrap();
        assert!(device.is_assigned());
        assert_eq!(device.assigned_vm_id(), Some(1));
        // Double assignment must fail until the device is released.
        assert!(device.assign_to_vm(2).is_err());
        device.unassign();
        assert!(!device.is_assigned());
    }

    #[test]
    fn test_vfio_device_unmap_bar() {
        let mut device = VfioDevice::open(1, PciAddress::new(0, 3, 0), 0x8086, 0x1234).unwrap();
        device
            .add_bar(BarRegion::new(0, 0xFE00_0000, 0x10000, BarFlags::MEMORY))
            .unwrap();
        device.map_bar(0, 0xC000_0000).unwrap();
        device.unmap_bar(0).unwrap();
        assert!(!device.bar(0).unwrap().mapped);
    }

    #[test]
    fn test_vfio_device_bar_not_found() {
        let mut device = VfioDevice::open(1, PciAddress::new(0, 3, 0), 0x8086, 0x1234).unwrap();
        assert!(device.map_bar(0, 0).is_err());
    }
}