// Copyright 2022, The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Functions to scan the PCI bus for VirtIO devices.
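//!
//! A minimal usage sketch, assuming `pci_info` has already been parsed from the device tree
//! (e.g. with the `fdtpci` crate) and a `PciRoot` has been constructed from it by the caller:
//!
//! ```ignore
//! let bar_region = get_bar_region(&pci_info);
//! // ... map `bar_region` as device memory before using any transport ...
//! check_pci(&mut pci_root);
//! ```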

use aarch64_paging::paging::MemoryRegion;
use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error, Layout};
use core::{mem::size_of, ptr::NonNull};
use fdtpci::PciInfo;
use log::{debug, info};
use virtio_drivers::{
    device::console::VirtIOConsole,
    transport::{
        pci::{bus::PciRoot, PciTransport},
        DeviceType, Transport,
    },
    BufferDirection, Error, Hal, PhysAddr, PAGE_SIZE,
};
use vmbase::virtio::pci::{self, PciTransportIterator};

/// The standard sector size of a VirtIO block device, in bytes.
const SECTOR_SIZE_BYTES: usize = 512;

/// The size in sectors of the test block device we expect.
const EXPECTED_SECTOR_COUNT: usize = 4;

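/// Checks that the PCI bus exposes the expected set of VirtIO devices, and exercises the block,
/// console and socket devices that are found.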
pub fn check_pci(pci_root: &mut PciRoot) {
    let mut checked_virtio_device_count = 0;
    let mut block_device_count = 0;
    let mut socket_device_count = 0;
    for mut transport in PciTransportIterator::<HalImpl>::new(pci_root) {
        info!(
            "Detected virtio PCI device with device type {:?}, features {:#018x}",
            transport.device_type(),
            transport.read_device_features(),
        );
        match transport.device_type() {
            DeviceType::Block => {
                check_virtio_block_device(transport, block_device_count);
                block_device_count += 1;
                checked_virtio_device_count += 1;
            }
            DeviceType::Console => {
                check_virtio_console_device(transport);
                checked_virtio_device_count += 1;
            }
            DeviceType::Socket => {
                check_virtio_socket_device(transport);
                socket_device_count += 1;
                checked_virtio_device_count += 1;
            }
            _ => {}
        }
    }

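    // Six supported VirtIO devices are expected in total: two block devices, one socket device
    // and (by implication) three console devices.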
    assert_eq!(checked_virtio_device_count, 6);
    assert_eq!(block_device_count, 2);
    assert_eq!(socket_device_count, 1);
}

/// Checks the given VirtIO block device.
fn check_virtio_block_device(transport: PciTransport, index: usize) {
    let mut blk = pci::VirtIOBlk::<HalImpl>::new(transport).expect("failed to create blk driver");
    info!("Found {} KiB block device.", blk.capacity() * SECTOR_SIZE_BYTES as u64 / 1024);
    match index {
        0 => {
            assert_eq!(blk.capacity(), EXPECTED_SECTOR_COUNT as u64);
            let mut data = [0; SECTOR_SIZE_BYTES * EXPECTED_SECTOR_COUNT];
            for i in 0..EXPECTED_SECTOR_COUNT {
                blk.read_blocks(i, &mut data[i * SECTOR_SIZE_BYTES..(i + 1) * SECTOR_SIZE_BYTES])
                    .expect("Failed to read block device.");
            }
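            // The test disk image is expected to contain consecutive 32-bit little-endian
            // integers (0, 1, 2, ...), one per 4-byte chunk.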
            for (i, chunk) in data.chunks(size_of::<u32>()).enumerate() {
                assert_eq!(chunk, &(i as u32).to_le_bytes());
            }
            info!("Read expected data from block device.");
        }
        1 => {
            assert_eq!(blk.capacity(), 0);
            let mut data = [0; SECTOR_SIZE_BYTES];
            assert_eq!(blk.read_blocks(0, &mut data), Err(Error::IoError));
        }
        _ => panic!("Unexpected VirtIO block device index {}.", index),
    }
}

/// Checks the given VirtIO socket device.
fn check_virtio_socket_device(transport: PciTransport) {
    let socket = pci::VirtIOSocket::<HalImpl>::new(transport)
        .expect("Failed to create VirtIO socket driver");
    info!("Found socket device: guest_cid={}", socket.guest_cid());
}

/// Checks the given VirtIO console device.
fn check_virtio_console_device(transport: PciTransport) {
    let mut console = VirtIOConsole::<HalImpl, PciTransport>::new(transport)
        .expect("Failed to create VirtIO console driver");
    info!("Found console device: {:?}", console.info());
    for &c in b"Hello VirtIO console\n" {
        console.send(c).expect("Failed to send character to VirtIO console device");
    }
    info!("Wrote to VirtIO console.");
}

/// Gets the memory region in which BARs are allocated.
pub fn get_bar_region(pci_info: &PciInfo) -> MemoryRegion {
    MemoryRegion::new(pci_info.bar_range.start as usize, pci_info.bar_range.end as usize)
}

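/// A `Hal` implementation which allocates DMA buffers from the global allocator and relies on an
/// identity mapping between virtual and physical addresses.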
struct HalImpl;

/// SAFETY: See the 'Implementation Safety' comments on methods below for how they fulfill the
/// safety requirements of the unsafe `Hal` trait.
unsafe impl Hal for HalImpl {
    /// # Implementation Safety
    ///
    /// `dma_alloc` ensures the returned DMA buffer is not aliased with any other allocation or
    /// reference in the program until it is deallocated by `dma_dealloc` by allocating a unique
    /// block of memory using `alloc_zeroed`, which is guaranteed to allocate valid, unique and
    /// zeroed memory. We request an alignment of at least `PAGE_SIZE` from `alloc_zeroed`.
    fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
        debug!("dma_alloc: pages={}", pages);
        let layout =
            Layout::from_size_align(pages.checked_mul(PAGE_SIZE).unwrap(), PAGE_SIZE).unwrap();
        assert_ne!(layout.size(), 0);
        // SAFETY: We just checked that the layout has a non-zero size.
        let vaddr = unsafe { alloc_zeroed(layout) };
        let vaddr =
            if let Some(vaddr) = NonNull::new(vaddr) { vaddr } else { handle_alloc_error(layout) };
        let paddr = virt_to_phys(vaddr);
        (paddr, vaddr)
    }

    unsafe fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
        debug!("dma_dealloc: paddr={:#x}, pages={}", paddr, pages);
        let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
        // SAFETY: The memory was allocated by `dma_alloc` above using the same allocator, and the
        // layout is the same as was used then.
        unsafe {
            dealloc(vaddr.as_ptr(), layout);
        }
        0
    }

    /// # Implementation Safety
    ///
    /// The returned pointer must be valid because `paddr` describes a valid MMIO region and we
    /// previously mapped the entire PCI MMIO range. It can't alias any other allocations because
    /// the PCI MMIO range doesn't overlap with any other memory ranges.
    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
        NonNull::new(paddr as _).unwrap()
    }

    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
        let vaddr = buffer.cast();
        // Nothing to do, as the host already has access to all memory.
        virt_to_phys(vaddr)
    }

    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
        // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
        // anywhere else.
    }
}

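/// Converts a virtual address to the corresponding physical address, relying on the identity
/// mapping used by this environment.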
fn virt_to_phys(vaddr: NonNull<u8>) -> PhysAddr {
    vaddr.as_ptr() as _
}