1 // Copyright 2023, The Android Open Source Project 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 //! Supports for the communication between rialto and host. 16 17 use crate::error::Result; 18 use ciborium_io::{Read, Write}; 19 use core::hint::spin_loop; 20 use core::mem; 21 use core::result; 22 use log::info; 23 use service_vm_comm::{Response, ServiceVmRequest}; 24 use tinyvec::ArrayVec; 25 use virtio_drivers::{ 26 self, 27 device::socket::{ 28 SocketError, VirtIOSocket, VsockAddr, VsockConnectionManager, VsockEventType, 29 }, 30 transport::Transport, 31 Hal, 32 }; 33 34 const WRITE_BUF_CAPACITY: usize = 512; 35 36 pub struct VsockStream<H: Hal, T: Transport> { 37 connection_manager: VsockConnectionManager<H, T>, 38 /// Peer address. The same port is used on rialto and peer for convenience. 39 peer_addr: VsockAddr, 40 write_buf: ArrayVec<[u8; WRITE_BUF_CAPACITY]>, 41 } 42 43 impl<H: Hal, T: Transport> VsockStream<H, T> { new( socket_device_driver: VirtIOSocket<H, T>, peer_addr: VsockAddr, ) -> virtio_drivers::Result<Self>44 pub fn new( 45 socket_device_driver: VirtIOSocket<H, T>, 46 peer_addr: VsockAddr, 47 ) -> virtio_drivers::Result<Self> { 48 let mut vsock_stream = Self { 49 connection_manager: VsockConnectionManager::new(socket_device_driver), 50 peer_addr, 51 write_buf: ArrayVec::default(), 52 }; 53 vsock_stream.connect()?; 54 Ok(vsock_stream) 55 } 56 connect(&mut self) -> virtio_drivers::Result57 fn connect(&mut self) -> virtio_drivers::Result { 58 self.connection_manager.connect(self.peer_addr, self.peer_addr.port)?; 59 self.wait_for_connect()?; 60 info!("Connected to the peer {:?}", self.peer_addr); 61 Ok(()) 62 } 63 wait_for_connect(&mut self) -> virtio_drivers::Result64 fn wait_for_connect(&mut self) -> virtio_drivers::Result { 65 loop { 66 if let Some(event) = self.poll_event_from_peer()? { 67 match event { 68 VsockEventType::Connected => return Ok(()), 69 VsockEventType::Disconnected { .. } => { 70 return Err(SocketError::ConnectionFailed.into()) 71 } 72 // We shouldn't receive the following event before the connection is 73 // established. 74 VsockEventType::ConnectionRequest | VsockEventType::Received { .. } => { 75 return Err(SocketError::InvalidOperation.into()) 76 } 77 // We can receive credit requests and updates at any time. 78 // This can be ignored as the connection manager handles them in poll(). 79 VsockEventType::CreditRequest | VsockEventType::CreditUpdate => {} 80 } 81 } else { 82 spin_loop(); 83 } 84 } 85 } 86 read_request(&mut self) -> Result<ServiceVmRequest>87 pub fn read_request(&mut self) -> Result<ServiceVmRequest> { 88 Ok(ciborium::from_reader(self)?) 89 } 90 write_response(&mut self, response: &Response) -> Result<()>91 pub fn write_response(&mut self, response: &Response) -> Result<()> { 92 Ok(ciborium::into_writer(response, self)?) 93 } 94 95 /// Shuts down the data channel. shutdown(&mut self) -> virtio_drivers::Result96 pub fn shutdown(&mut self) -> virtio_drivers::Result { 97 self.connection_manager.force_close(self.peer_addr, self.peer_addr.port)?; 98 info!("Connection shutdown."); 99 Ok(()) 100 } 101 recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize>102 fn recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> { 103 let bytes_read = 104 self.connection_manager.recv(self.peer_addr, self.peer_addr.port, buffer)?; 105 106 let buffer_available_bytes = self 107 .connection_manager 108 .recv_buffer_available_bytes(self.peer_addr, self.peer_addr.port)?; 109 if buffer_available_bytes == 0 && bytes_read > 0 { 110 self.connection_manager.update_credit(self.peer_addr, self.peer_addr.port)?; 111 } 112 Ok(bytes_read) 113 } 114 wait_for_send(&mut self, buffer: &[u8]) -> virtio_drivers::Result115 fn wait_for_send(&mut self, buffer: &[u8]) -> virtio_drivers::Result { 116 const INSUFFICIENT_BUFFER_SPACE_ERROR: virtio_drivers::Error = 117 virtio_drivers::Error::SocketDeviceError(SocketError::InsufficientBufferSpaceInPeer); 118 loop { 119 match self.connection_manager.send(self.peer_addr, self.peer_addr.port, buffer) { 120 Ok(_) => return Ok(()), 121 Err(INSUFFICIENT_BUFFER_SPACE_ERROR) => { 122 self.poll()?; 123 } 124 Err(e) => return Err(e), 125 } 126 } 127 } 128 wait_for_recv(&mut self) -> virtio_drivers::Result129 fn wait_for_recv(&mut self) -> virtio_drivers::Result { 130 loop { 131 match self.poll()? { 132 Some(VsockEventType::Received { .. }) => return Ok(()), 133 _ => spin_loop(), 134 } 135 } 136 } 137 138 /// Polls the rx queue after the connection is established with the peer, this function 139 /// rejects some invalid events. The valid events are handled inside the connection 140 /// manager. poll(&mut self) -> virtio_drivers::Result<Option<VsockEventType>>141 fn poll(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> { 142 if let Some(event) = self.poll_event_from_peer()? { 143 match event { 144 VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()), 145 VsockEventType::Connected | VsockEventType::ConnectionRequest => { 146 Err(SocketError::InvalidOperation.into()) 147 } 148 // When there is a received event, the received data is buffered in the 149 // connection manager's internal receive buffer, so we don't need to do 150 // anything here. 151 // The credit request and updates also handled inside the connection 152 // manager. 153 VsockEventType::Received { .. } 154 | VsockEventType::CreditRequest 155 | VsockEventType::CreditUpdate => Ok(Some(event)), 156 } 157 } else { 158 Ok(None) 159 } 160 } 161 poll_event_from_peer(&mut self) -> virtio_drivers::Result<Option<VsockEventType>>162 fn poll_event_from_peer(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> { 163 Ok(self.connection_manager.poll()?.map(|event| { 164 assert_eq!(event.source, self.peer_addr); 165 assert_eq!(event.destination.port, self.peer_addr.port); 166 event.event_type 167 })) 168 } 169 } 170 171 impl<H: Hal, T: Transport> Read for VsockStream<H, T> { 172 type Error = virtio_drivers::Error; 173 read_exact(&mut self, data: &mut [u8]) -> result::Result<(), Self::Error>174 fn read_exact(&mut self, data: &mut [u8]) -> result::Result<(), Self::Error> { 175 let mut start = 0; 176 while start < data.len() { 177 let len = self.recv(&mut data[start..])?; 178 let len = if len == 0 { 179 self.wait_for_recv()?; 180 self.recv(&mut data[start..])? 181 } else { 182 len 183 }; 184 start += len; 185 } 186 Ok(()) 187 } 188 } 189 190 impl<H: Hal, T: Transport> Write for VsockStream<H, T> { 191 type Error = virtio_drivers::Error; 192 write_all(&mut self, data: &[u8]) -> result::Result<(), Self::Error>193 fn write_all(&mut self, data: &[u8]) -> result::Result<(), Self::Error> { 194 if data.len() >= self.write_buf.capacity() - self.write_buf.len() { 195 self.flush()?; 196 if data.len() >= self.write_buf.capacity() { 197 self.wait_for_send(data)?; 198 return Ok(()); 199 } 200 } 201 self.write_buf.extend_from_slice(data); 202 Ok(()) 203 } 204 flush(&mut self) -> result::Result<(), Self::Error>205 fn flush(&mut self) -> result::Result<(), Self::Error> { 206 if !self.write_buf.is_empty() { 207 // We need to take the memory from self.write_buf to a temporary 208 // buffer to avoid borrowing `*self` as mutable and immutable on 209 // the same time in `self.wait_for_send(&self.write_buf)`. 210 let buffer = mem::take(&mut self.write_buf); 211 self.wait_for_send(&buffer)?; 212 } 213 Ok(()) 214 } 215 } 216