1 // Copyright 2023, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
//! Support for communication between rialto and the host.
16 
17 use crate::error::Result;
18 use ciborium_io::{Read, Write};
19 use core::hint::spin_loop;
20 use core::mem;
21 use core::result;
22 use log::info;
23 use service_vm_comm::{Response, ServiceVmRequest};
24 use tinyvec::ArrayVec;
25 use virtio_drivers::{
26     self,
27     device::socket::{
28         SocketError, VirtIOSocket, VsockAddr, VsockConnectionManager, VsockEventType,
29     },
30     transport::Transport,
31     Hal,
32 };
33 
34 const WRITE_BUF_CAPACITY: usize = 512;
35 
36 pub struct VsockStream<H: Hal, T: Transport> {
37     connection_manager: VsockConnectionManager<H, T>,
38     /// Peer address. The same port is used on rialto and peer for convenience.
39     peer_addr: VsockAddr,
40     write_buf: ArrayVec<[u8; WRITE_BUF_CAPACITY]>,
41 }
42 
43 impl<H: Hal, T: Transport> VsockStream<H, T> {
new( socket_device_driver: VirtIOSocket<H, T>, peer_addr: VsockAddr, ) -> virtio_drivers::Result<Self>44     pub fn new(
45         socket_device_driver: VirtIOSocket<H, T>,
46         peer_addr: VsockAddr,
47     ) -> virtio_drivers::Result<Self> {
48         let mut vsock_stream = Self {
49             connection_manager: VsockConnectionManager::new(socket_device_driver),
50             peer_addr,
51             write_buf: ArrayVec::default(),
52         };
53         vsock_stream.connect()?;
54         Ok(vsock_stream)
55     }
56 
connect(&mut self) -> virtio_drivers::Result57     fn connect(&mut self) -> virtio_drivers::Result {
58         self.connection_manager.connect(self.peer_addr, self.peer_addr.port)?;
59         self.wait_for_connect()?;
60         info!("Connected to the peer {:?}", self.peer_addr);
61         Ok(())
62     }
63 
wait_for_connect(&mut self) -> virtio_drivers::Result64     fn wait_for_connect(&mut self) -> virtio_drivers::Result {
65         loop {
66             if let Some(event) = self.poll_event_from_peer()? {
67                 match event {
68                     VsockEventType::Connected => return Ok(()),
69                     VsockEventType::Disconnected { .. } => {
70                         return Err(SocketError::ConnectionFailed.into())
71                     }
72                     // We shouldn't receive the following event before the connection is
73                     // established.
74                     VsockEventType::ConnectionRequest | VsockEventType::Received { .. } => {
75                         return Err(SocketError::InvalidOperation.into())
76                     }
77                     // We can receive credit requests and updates at any time.
78                     // This can be ignored as the connection manager handles them in poll().
79                     VsockEventType::CreditRequest | VsockEventType::CreditUpdate => {}
80                 }
81             } else {
82                 spin_loop();
83             }
84         }
85     }
86 
read_request(&mut self) -> Result<ServiceVmRequest>87     pub fn read_request(&mut self) -> Result<ServiceVmRequest> {
88         Ok(ciborium::from_reader(self)?)
89     }
90 
write_response(&mut self, response: &Response) -> Result<()>91     pub fn write_response(&mut self, response: &Response) -> Result<()> {
92         Ok(ciborium::into_writer(response, self)?)
93     }
94 
95     /// Shuts down the data channel.
shutdown(&mut self) -> virtio_drivers::Result96     pub fn shutdown(&mut self) -> virtio_drivers::Result {
97         self.connection_manager.force_close(self.peer_addr, self.peer_addr.port)?;
98         info!("Connection shutdown.");
99         Ok(())
100     }
101 
recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize>102     fn recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
103         let bytes_read =
104             self.connection_manager.recv(self.peer_addr, self.peer_addr.port, buffer)?;
105 
106         let buffer_available_bytes = self
107             .connection_manager
108             .recv_buffer_available_bytes(self.peer_addr, self.peer_addr.port)?;
109         if buffer_available_bytes == 0 && bytes_read > 0 {
110             self.connection_manager.update_credit(self.peer_addr, self.peer_addr.port)?;
111         }
112         Ok(bytes_read)
113     }
114 
wait_for_send(&mut self, buffer: &[u8]) -> virtio_drivers::Result115     fn wait_for_send(&mut self, buffer: &[u8]) -> virtio_drivers::Result {
116         const INSUFFICIENT_BUFFER_SPACE_ERROR: virtio_drivers::Error =
117             virtio_drivers::Error::SocketDeviceError(SocketError::InsufficientBufferSpaceInPeer);
118         loop {
119             match self.connection_manager.send(self.peer_addr, self.peer_addr.port, buffer) {
120                 Ok(_) => return Ok(()),
121                 Err(INSUFFICIENT_BUFFER_SPACE_ERROR) => {
122                     self.poll()?;
123                 }
124                 Err(e) => return Err(e),
125             }
126         }
127     }
128 
wait_for_recv(&mut self) -> virtio_drivers::Result129     fn wait_for_recv(&mut self) -> virtio_drivers::Result {
130         loop {
131             match self.poll()? {
132                 Some(VsockEventType::Received { .. }) => return Ok(()),
133                 _ => spin_loop(),
134             }
135         }
136     }
137 
138     /// Polls the rx queue after the connection is established with the peer, this function
139     /// rejects some invalid events. The valid events are handled inside the connection
140     /// manager.
poll(&mut self) -> virtio_drivers::Result<Option<VsockEventType>>141     fn poll(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
142         if let Some(event) = self.poll_event_from_peer()? {
143             match event {
144                 VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
145                 VsockEventType::Connected | VsockEventType::ConnectionRequest => {
146                     Err(SocketError::InvalidOperation.into())
147                 }
148                 // When there is a received event, the received data is buffered in the
149                 // connection manager's internal receive buffer, so we don't need to do
150                 // anything here.
151                 // The credit request and updates also handled inside the connection
152                 // manager.
153                 VsockEventType::Received { .. }
154                 | VsockEventType::CreditRequest
155                 | VsockEventType::CreditUpdate => Ok(Some(event)),
156             }
157         } else {
158             Ok(None)
159         }
160     }
161 
poll_event_from_peer(&mut self) -> virtio_drivers::Result<Option<VsockEventType>>162     fn poll_event_from_peer(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
163         Ok(self.connection_manager.poll()?.map(|event| {
164             assert_eq!(event.source, self.peer_addr);
165             assert_eq!(event.destination.port, self.peer_addr.port);
166             event.event_type
167         }))
168     }
169 }
170 
171 impl<H: Hal, T: Transport> Read for VsockStream<H, T> {
172     type Error = virtio_drivers::Error;
173 
read_exact(&mut self, data: &mut [u8]) -> result::Result<(), Self::Error>174     fn read_exact(&mut self, data: &mut [u8]) -> result::Result<(), Self::Error> {
175         let mut start = 0;
176         while start < data.len() {
177             let len = self.recv(&mut data[start..])?;
178             let len = if len == 0 {
179                 self.wait_for_recv()?;
180                 self.recv(&mut data[start..])?
181             } else {
182                 len
183             };
184             start += len;
185         }
186         Ok(())
187     }
188 }
189 
190 impl<H: Hal, T: Transport> Write for VsockStream<H, T> {
191     type Error = virtio_drivers::Error;
192 
write_all(&mut self, data: &[u8]) -> result::Result<(), Self::Error>193     fn write_all(&mut self, data: &[u8]) -> result::Result<(), Self::Error> {
194         if data.len() >= self.write_buf.capacity() - self.write_buf.len() {
195             self.flush()?;
196             if data.len() >= self.write_buf.capacity() {
197                 self.wait_for_send(data)?;
198                 return Ok(());
199             }
200         }
201         self.write_buf.extend_from_slice(data);
202         Ok(())
203     }
204 
flush(&mut self) -> result::Result<(), Self::Error>205     fn flush(&mut self) -> result::Result<(), Self::Error> {
206         if !self.write_buf.is_empty() {
207             // We need to take the memory from self.write_buf to a temporary
208             // buffer to avoid borrowing `*self` as mutable and immutable on
209             // the same time in `self.wait_for_send(&self.write_buf)`.
210             let buffer = mem::take(&mut self.write_buf);
211             self.wait_for_send(&buffer)?;
212         }
213         Ok(())
214     }
215 }
216