1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::io::{self, IoSlice};
6 use std::marker::PhantomData;
7 use std::ops::Deref;
8 use std::os::unix::prelude::{AsRawFd, RawFd};
9 use std::time::Duration;
10 
11 use crate::{net::UnixSeqpacket, FromRawDescriptor, SafeDescriptor, ScmSocket, UnsyncMarker};
12 
13 use cros_async::{Executor, IntoAsync, IoSourceExt};
14 use serde::{de::DeserializeOwned, Serialize};
15 use sys_util::{
16     deserialize_with_descriptors, AsRawDescriptor, RawDescriptor, SerializeDescriptors,
17 };
18 use thiserror::Error as ThisError;
19 
20 #[derive(ThisError, Debug)]
21 pub enum Error {
22     #[error("failed to serialize/deserialize json from packet: {0}")]
23     Json(serde_json::Error),
24     #[error("failed to send packet: {0}")]
25     Send(sys_util::Error),
26     #[error("failed to receive packet: {0}")]
27     Recv(io::Error),
28     #[error("tube was disconnected")]
29     Disconnected,
30     #[error("failed to crate tube pair: {0}")]
31     Pair(io::Error),
32     #[error("failed to set send timeout: {0}")]
33     SetSendTimeout(io::Error),
34     #[error("failed to set recv timeout: {0}")]
35     SetRecvTimeout(io::Error),
36     #[error("failed to create async tube: {0}")]
37     CreateAsync(cros_async::AsyncError),
38 }
39 
40 pub type Result<T> = std::result::Result<T, Error>;
41 
/// Bidirectional tube that supports both send and recv.
///
/// Messages are serialized as JSON and carried, together with any descriptors they contain,
/// over a single `UnixSeqpacket` socket.
pub struct Tube {
    // Connected socket that every send/recv goes through.
    socket: UnixSeqpacket,
    // NOTE(review): zero-sized marker; presumably opts `Tube` out of `Sync` (see
    // `UnsyncMarker`) — confirm.
    _unsync_marker: UnsyncMarker,
}
47 
48 impl Tube {
49     /// Create a pair of connected tubes. Request is send in one direction while response is in the
50     /// other direction.
pair() -> Result<(Tube, Tube)>51     pub fn pair() -> Result<(Tube, Tube)> {
52         let (socket1, socket2) = UnixSeqpacket::pair().map_err(Error::Pair)?;
53         let tube1 = Tube::new(socket1);
54         let tube2 = Tube::new(socket2);
55         Ok((tube1, tube2))
56     }
57 
58     // Create a new `Tube`.
new(socket: UnixSeqpacket) -> Tube59     pub fn new(socket: UnixSeqpacket) -> Tube {
60         Tube {
61             socket,
62             _unsync_marker: PhantomData,
63         }
64     }
65 
into_async_tube(self, ex: &Executor) -> Result<AsyncTube>66     pub fn into_async_tube(self, ex: &Executor) -> Result<AsyncTube> {
67         let inner = ex.async_from(self).map_err(Error::CreateAsync)?;
68         Ok(AsyncTube { inner })
69     }
70 
send<T: Serialize>(&self, msg: &T) -> Result<()>71     pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> {
72         let msg_serialize = SerializeDescriptors::new(&msg);
73         let msg_json = serde_json::to_vec(&msg_serialize).map_err(Error::Json)?;
74         let msg_descriptors = msg_serialize.into_descriptors();
75 
76         self.socket
77             .send_with_fds(&[IoSlice::new(&msg_json)], &msg_descriptors)
78             .map_err(Error::Send)?;
79         Ok(())
80     }
81 
recv<T: DeserializeOwned>(&self) -> Result<T>82     pub fn recv<T: DeserializeOwned>(&self) -> Result<T> {
83         let (msg_json, msg_descriptors) =
84             self.socket.recv_as_vec_with_fds().map_err(Error::Recv)?;
85 
86         if msg_json.is_empty() {
87             return Err(Error::Disconnected);
88         }
89 
90         let mut msg_descriptors_safe = msg_descriptors
91             .into_iter()
92             .map(|v| {
93                 Some(unsafe {
94                     // Safe because the socket returns new fds that are owned locally by this scope.
95                     SafeDescriptor::from_raw_descriptor(v)
96                 })
97             })
98             .collect();
99 
100         deserialize_with_descriptors(
101             || serde_json::from_slice(&msg_json),
102             &mut msg_descriptors_safe,
103         )
104         .map_err(Error::Json)
105     }
106 
107     /// Returns true if there is a packet ready to `recv` without blocking.
108     ///
109     /// If there is an error trying to determine if there is a packet ready, this returns false.
is_packet_ready(&self) -> bool110     pub fn is_packet_ready(&self) -> bool {
111         self.socket.get_readable_bytes().unwrap_or(0) > 0
112     }
113 
set_send_timeout(&self, timeout: Option<Duration>) -> Result<()>114     pub fn set_send_timeout(&self, timeout: Option<Duration>) -> Result<()> {
115         self.socket
116             .set_write_timeout(timeout)
117             .map_err(Error::SetSendTimeout)
118     }
119 
set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()>120     pub fn set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()> {
121         self.socket
122             .set_read_timeout(timeout)
123             .map_err(Error::SetRecvTimeout)
124     }
125 }
126 
127 impl AsRawDescriptor for Tube {
as_raw_descriptor(&self) -> RawDescriptor128     fn as_raw_descriptor(&self) -> RawDescriptor {
129         self.socket.as_raw_descriptor()
130     }
131 }
132 
133 impl AsRawFd for Tube {
as_raw_fd(&self) -> RawFd134     fn as_raw_fd(&self) -> RawFd {
135         self.socket.as_raw_fd()
136     }
137 }
138 
139 impl IntoAsync for Tube {}
140 
141 pub struct AsyncTube {
142     inner: Box<dyn IoSourceExt<Tube>>,
143 }
144 
145 impl AsyncTube {
next<T: DeserializeOwned>(&self) -> Result<T>146     pub async fn next<T: DeserializeOwned>(&self) -> Result<T> {
147         self.inner.wait_readable().await.unwrap();
148         self.inner.as_source().recv()
149     }
150 }
151 
152 impl Deref for AsyncTube {
153     type Target = Tube;
154 
deref(&self) -> &Self::Target155     fn deref(&self) -> &Self::Target {
156         self.inner.as_source()
157     }
158 }
159 
160 impl Into<Tube> for AsyncTube {
into(self) -> Tube161     fn into(self) -> Tube {
162         self.inner.into_source()
163     }
164 }
165 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Event;

    use std::collections::HashMap;
    use std::time::Duration;

    use serde::{Deserialize, Serialize};

    /// Signals `send`, then reads `recv` with a one-second timeout, unwrapping both results.
    ///
    /// NOTE(review): if `read_timeout` reports an expired timeout through its `Ok` value rather
    /// than as an error, this helper will not catch an event pair that is not connected —
    /// confirm against `Event::read_timeout`.
    #[track_caller]
    fn test_event_pair(send: Event, mut recv: Event) {
        send.write(1).unwrap();
        recv.read_timeout(Duration::from_secs(1)).unwrap();
    }

    #[test]
    fn send_recv_no_fd() {
        let (sender, receiver) = Tube::pair().unwrap();

        let expected = "hello world";
        sender.send(&expected).unwrap();
        let received: String = receiver.recv().unwrap();

        assert_eq!(expected, received);
    }

    #[test]
    fn send_recv_one_fd() {
        #[derive(Serialize, Deserialize)]
        struct EventStruct {
            x: u32,
            b: Event,
        }

        let (sender, receiver) = Tube::pair().unwrap();

        let sent = EventStruct {
            x: 100,
            b: Event::new().unwrap(),
        };
        sender.send(&sent).unwrap();
        let received: EventStruct = receiver.recv().unwrap();

        assert_eq!(sent.x, received.x);

        test_event_pair(sent.b, received.b);
    }

    #[test]
    fn send_recv_hash_map() {
        let (sender, receiver) = Tube::pair().unwrap();

        let sent: HashMap<String, Event> = ["Red", "White", "Blue", "Orange", "Green"]
            .iter()
            .map(|&color| (color.to_owned(), Event::new().unwrap()))
            .collect();
        sender.send(&sent).unwrap();
        let mut received: HashMap<String, Event> = receiver.recv().unwrap();

        // HashMap iteration order is unspecified, so compare the key sets in sorted form.
        let sorted_keys = |map: &HashMap<String, Event>| {
            let mut keys: Vec<String> = map.keys().cloned().collect();
            keys.sort();
            keys
        };
        assert_eq!(sorted_keys(&sent), sorted_keys(&received));

        for (color, sent_event) in sent {
            let received_event = received.remove(&color).unwrap();
            test_event_pair(sent_event, received_event);
        }
    }
}
240