1 // Copyright 2019 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 use std::{array::TryFromSliceError, convert::TryInto, error, fmt, io, mem, os::unix::io::RawFd};
5 
6 use cras_sys::gen::{
7     cras_client_connected, cras_client_message, cras_client_stream_connected,
8     CRAS_CLIENT_MAX_MSG_SIZE,
9     CRAS_CLIENT_MESSAGE_ID::{self, *},
10 };
11 use data_model::DataInit;
12 use sys_util::ScmSocket;
13 
14 use crate::cras_server_socket::CrasServerSocket;
15 use crate::cras_shm::*;
16 use crate::cras_stream;
17 
18 #[derive(Debug)]
19 pub enum Error {
20     IoError(io::Error),
21     SysUtilError(sys_util::Error),
22     CrasStreamError(cras_stream::Error),
23     ArrayTryFromSliceError(TryFromSliceError),
24     InvalidSize,
25     MessageTypeError,
26     MessageNumFdError,
27     MessageTruncated,
28     MessageIdError,
29     MessageFromSliceError,
30 }
31 
32 impl error::Error for Error {}
33 
34 impl fmt::Display for Error {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result35     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
36         match self {
37             Error::IoError(ref err) => err.fmt(f),
38             Error::SysUtilError(ref err) => err.fmt(f),
39             Error::MessageTypeError => write!(f, "Message type error"),
40             Error::CrasStreamError(ref err) => err.fmt(f),
41             Error::ArrayTryFromSliceError(ref err) => err.fmt(f),
42             Error::MessageNumFdError => write!(f, "Message the number of fds is not matched"),
43             Error::MessageTruncated => write!(f, "Read truncated message"),
44             Error::MessageIdError => write!(f, "No such id"),
45             Error::MessageFromSliceError => write!(f, "Message from slice error"),
46             Error::InvalidSize => write!(f, "Invalid data size"),
47         }
48     }
49 }
50 
51 type Result<T> = std::result::Result<T, Error>;
52 
53 impl From<io::Error> for Error {
from(io_err: io::Error) -> Self54     fn from(io_err: io::Error) -> Self {
55         Error::IoError(io_err)
56     }
57 }
58 
59 impl From<sys_util::Error> for Error {
from(sys_util_err: sys_util::Error) -> Self60     fn from(sys_util_err: sys_util::Error) -> Self {
61         Error::SysUtilError(sys_util_err)
62     }
63 }
64 
65 impl From<cras_stream::Error> for Error {
from(err: cras_stream::Error) -> Self66     fn from(err: cras_stream::Error) -> Self {
67         Error::CrasStreamError(err)
68     }
69 }
70 
71 impl From<TryFromSliceError> for Error {
from(err: TryFromSliceError) -> Self72     fn from(err: TryFromSliceError) -> Self {
73         Error::ArrayTryFromSliceError(err)
74     }
75 }
76 
77 /// A handled server result from one message sent from CRAS server.
78 pub enum ServerResult {
79     /// client_id, CrasServerStateShmFd
80     Connected(u32, CrasServerStateShmFd),
81     /// stream_id, header_fd, samples_fd
82     StreamConnected(u32, CrasAudioShmHeaderFd, CrasShmFd),
83     DebugInfoReady,
84 }
85 
86 impl ServerResult {
87     /// Reads and handles one server message and converts `CrasClientMessage` into `ServerResult`
88     /// with error handling.
89     ///
90     /// # Arguments
91     /// * `server_socket`: A reference to `CrasServerSocket`.
handle_server_message(server_socket: &CrasServerSocket) -> Result<ServerResult>92     pub fn handle_server_message(server_socket: &CrasServerSocket) -> Result<ServerResult> {
93         let message = CrasClientMessage::try_new(&server_socket)?;
94         match message.get_id()? {
95             CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_CONNECTED => {
96                 let cmsg: &cras_client_connected = message.get_message()?;
97                 // CRAS server should return a shared memory area which contains
98                 // `cras_server_state`.
99                 let server_state_fd = unsafe { CrasServerStateShmFd::new(message.fds[0]) };
100                 Ok(ServerResult::Connected(cmsg.client_id, server_state_fd))
101             }
102             CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_STREAM_CONNECTED => {
103                 let cmsg: &cras_client_stream_connected = message.get_message()?;
104                 // CRAS should return two shared memory areas the first which has
105                 // mem::size_of::<cras_audio_shm_header>() bytes, and the second which has
106                 // `samples_shm_size` bytes.
107                 Ok(ServerResult::StreamConnected(
108                     cmsg.stream_id,
109                     // Safe because CRAS ensures that the first fd contains a cras_audio_shm_header
110                     unsafe { CrasAudioShmHeaderFd::new(message.fds[0]) },
111                     // Safe because CRAS ensures that the second fd has length 'samples_shm_size'
112                     unsafe { CrasShmFd::new(message.fds[1], cmsg.samples_shm_size as usize) },
113                 ))
114             }
115             CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_AUDIO_DEBUG_INFO_READY => {
116                 Ok(ServerResult::DebugInfoReady)
117             }
118             _ => Err(Error::MessageTypeError),
119         }
120     }
121 }
122 
123 // A structure for raw message with fds from CRAS server.
124 struct CrasClientMessage {
125     fds: [RawFd; 2],
126     data: [u8; CRAS_CLIENT_MAX_MSG_SIZE as usize],
127     len: usize,
128 }
129 
130 /// The default constructor won't be used outside of this file and it's an optimization to prevent
131 /// having to copy the message data from a temp buffer.
132 impl Default for CrasClientMessage {
133     // Initializes fields with default values.
default() -> Self134     fn default() -> Self {
135         Self {
136             fds: [-1; 2],
137             data: [0; CRAS_CLIENT_MAX_MSG_SIZE as usize],
138             len: 0,
139         }
140     }
141 }
142 
143 impl CrasClientMessage {
144     // Reads a message from server_socket and checks validity of the read result
try_new(server_socket: &CrasServerSocket) -> Result<CrasClientMessage>145     fn try_new(server_socket: &CrasServerSocket) -> Result<CrasClientMessage> {
146         let mut message: Self = Default::default();
147         let (len, fd_nums) = server_socket.recv_with_fds(&mut message.data, &mut message.fds)?;
148 
149         if len < mem::size_of::<cras_client_message>() {
150             Err(Error::MessageTruncated)
151         } else {
152             message.len = len;
153             message.check_fd_nums(fd_nums)?;
154             Ok(message)
155         }
156     }
157 
158     // Check if `fd nums` of a read result is valid
check_fd_nums(&self, fd_nums: usize) -> Result<()>159     fn check_fd_nums(&self, fd_nums: usize) -> Result<()> {
160         match self.get_id()? {
161             CRAS_CLIENT_CONNECTED => match fd_nums {
162                 1 => Ok(()),
163                 _ => Err(Error::MessageNumFdError),
164             },
165             CRAS_CLIENT_STREAM_CONNECTED => match fd_nums {
166                 // CRAS should return two shared memory areas the first which has
167                 // mem::size_of::<cras_audio_shm_header>() bytes, and the second which has
168                 // `samples_shm_size` bytes.
169                 2 => Ok(()),
170                 _ => Err(Error::MessageNumFdError),
171             },
172             CRAS_CLIENT_AUDIO_DEBUG_INFO_READY => match fd_nums {
173                 0 => Ok(()),
174                 _ => Err(Error::MessageNumFdError),
175             },
176             _ => Err(Error::MessageTypeError),
177         }
178     }
179 
180     // Gets the message id
get_id(&self) -> Result<CRAS_CLIENT_MESSAGE_ID>181     fn get_id(&self) -> Result<CRAS_CLIENT_MESSAGE_ID> {
182         let offset = mem::size_of::<u32>();
183         match u32::from_le_bytes(self.data[offset..offset + 4].try_into()?) {
184             id if id == (CRAS_CLIENT_CONNECTED as u32) => Ok(CRAS_CLIENT_CONNECTED),
185             id if id == (CRAS_CLIENT_STREAM_CONNECTED as u32) => Ok(CRAS_CLIENT_STREAM_CONNECTED),
186             id if id == (CRAS_CLIENT_AUDIO_DEBUG_INFO_READY as u32) => {
187                 Ok(CRAS_CLIENT_AUDIO_DEBUG_INFO_READY)
188             }
189             _ => Err(Error::MessageIdError),
190         }
191     }
192 
193     // Gets a reference to the message content
get_message<T: DataInit>(&self) -> Result<&T>194     fn get_message<T: DataInit>(&self) -> Result<&T> {
195         if self.len != mem::size_of::<T>() {
196             return Err(Error::InvalidSize);
197         }
198         T::from_slice(&self.data[..mem::size_of::<T>()]).ok_or(Error::MessageFromSliceError)
199     }
200 }
201