1 // Copyright 2019 The Chromium OS Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 use std::{array::TryFromSliceError, convert::TryInto, error, fmt, io, mem, os::unix::io::RawFd}; 5 6 use cras_sys::gen::{ 7 cras_client_connected, cras_client_message, cras_client_stream_connected, 8 CRAS_CLIENT_MAX_MSG_SIZE, 9 CRAS_CLIENT_MESSAGE_ID::{self, *}, 10 }; 11 use data_model::DataInit; 12 use sys_util::ScmSocket; 13 14 use crate::cras_server_socket::CrasServerSocket; 15 use crate::cras_shm::*; 16 use crate::cras_stream; 17 18 #[derive(Debug)] 19 pub enum Error { 20 IoError(io::Error), 21 SysUtilError(sys_util::Error), 22 CrasStreamError(cras_stream::Error), 23 ArrayTryFromSliceError(TryFromSliceError), 24 InvalidSize, 25 MessageTypeError, 26 MessageNumFdError, 27 MessageTruncated, 28 MessageIdError, 29 MessageFromSliceError, 30 } 31 32 impl error::Error for Error {} 33 34 impl fmt::Display for Error { fmt(&self, f: &mut fmt::Formatter) -> fmt::Result35 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 36 match self { 37 Error::IoError(ref err) => err.fmt(f), 38 Error::SysUtilError(ref err) => err.fmt(f), 39 Error::MessageTypeError => write!(f, "Message type error"), 40 Error::CrasStreamError(ref err) => err.fmt(f), 41 Error::ArrayTryFromSliceError(ref err) => err.fmt(f), 42 Error::MessageNumFdError => write!(f, "Message the number of fds is not matched"), 43 Error::MessageTruncated => write!(f, "Read truncated message"), 44 Error::MessageIdError => write!(f, "No such id"), 45 Error::MessageFromSliceError => write!(f, "Message from slice error"), 46 Error::InvalidSize => write!(f, "Invalid data size"), 47 } 48 } 49 } 50 51 type Result<T> = std::result::Result<T, Error>; 52 53 impl From<io::Error> for Error { from(io_err: io::Error) -> Self54 fn from(io_err: io::Error) -> Self { 55 Error::IoError(io_err) 56 } 57 } 58 59 impl From<sys_util::Error> for Error { from(sys_util_err: sys_util::Error) -> Self60 fn from(sys_util_err: sys_util::Error) -> Self { 61 Error::SysUtilError(sys_util_err) 62 } 63 } 64 65 impl From<cras_stream::Error> for Error { from(err: cras_stream::Error) -> Self66 fn from(err: cras_stream::Error) -> Self { 67 Error::CrasStreamError(err) 68 } 69 } 70 71 impl From<TryFromSliceError> for Error { from(err: TryFromSliceError) -> Self72 fn from(err: TryFromSliceError) -> Self { 73 Error::ArrayTryFromSliceError(err) 74 } 75 } 76 77 /// A handled server result from one message sent from CRAS server. 78 pub enum ServerResult { 79 /// client_id, CrasServerStateShmFd 80 Connected(u32, CrasServerStateShmFd), 81 /// stream_id, header_fd, samples_fd 82 StreamConnected(u32, CrasAudioShmHeaderFd, CrasShmFd), 83 DebugInfoReady, 84 } 85 86 impl ServerResult { 87 /// Reads and handles one server message and converts `CrasClientMessage` into `ServerResult` 88 /// with error handling. 89 /// 90 /// # Arguments 91 /// * `server_socket`: A reference to `CrasServerSocket`. handle_server_message(server_socket: &CrasServerSocket) -> Result<ServerResult>92 pub fn handle_server_message(server_socket: &CrasServerSocket) -> Result<ServerResult> { 93 let message = CrasClientMessage::try_new(&server_socket)?; 94 match message.get_id()? { 95 CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_CONNECTED => { 96 let cmsg: &cras_client_connected = message.get_message()?; 97 // CRAS server should return a shared memory area which contains 98 // `cras_server_state`. 99 let server_state_fd = unsafe { CrasServerStateShmFd::new(message.fds[0]) }; 100 Ok(ServerResult::Connected(cmsg.client_id, server_state_fd)) 101 } 102 CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_STREAM_CONNECTED => { 103 let cmsg: &cras_client_stream_connected = message.get_message()?; 104 // CRAS should return two shared memory areas the first which has 105 // mem::size_of::<cras_audio_shm_header>() bytes, and the second which has 106 // `samples_shm_size` bytes. 107 Ok(ServerResult::StreamConnected( 108 cmsg.stream_id, 109 // Safe because CRAS ensures that the first fd contains a cras_audio_shm_header 110 unsafe { CrasAudioShmHeaderFd::new(message.fds[0]) }, 111 // Safe because CRAS ensures that the second fd has length 'samples_shm_size' 112 unsafe { CrasShmFd::new(message.fds[1], cmsg.samples_shm_size as usize) }, 113 )) 114 } 115 CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_AUDIO_DEBUG_INFO_READY => { 116 Ok(ServerResult::DebugInfoReady) 117 } 118 _ => Err(Error::MessageTypeError), 119 } 120 } 121 } 122 123 // A structure for raw message with fds from CRAS server. 124 struct CrasClientMessage { 125 fds: [RawFd; 2], 126 data: [u8; CRAS_CLIENT_MAX_MSG_SIZE as usize], 127 len: usize, 128 } 129 130 /// The default constructor won't be used outside of this file and it's an optimization to prevent 131 /// having to copy the message data from a temp buffer. 132 impl Default for CrasClientMessage { 133 // Initializes fields with default values. default() -> Self134 fn default() -> Self { 135 Self { 136 fds: [-1; 2], 137 data: [0; CRAS_CLIENT_MAX_MSG_SIZE as usize], 138 len: 0, 139 } 140 } 141 } 142 143 impl CrasClientMessage { 144 // Reads a message from server_socket and checks validity of the read result try_new(server_socket: &CrasServerSocket) -> Result<CrasClientMessage>145 fn try_new(server_socket: &CrasServerSocket) -> Result<CrasClientMessage> { 146 let mut message: Self = Default::default(); 147 let (len, fd_nums) = server_socket.recv_with_fds(&mut message.data, &mut message.fds)?; 148 149 if len < mem::size_of::<cras_client_message>() { 150 Err(Error::MessageTruncated) 151 } else { 152 message.len = len; 153 message.check_fd_nums(fd_nums)?; 154 Ok(message) 155 } 156 } 157 158 // Check if `fd nums` of a read result is valid check_fd_nums(&self, fd_nums: usize) -> Result<()>159 fn check_fd_nums(&self, fd_nums: usize) -> Result<()> { 160 match self.get_id()? { 161 CRAS_CLIENT_CONNECTED => match fd_nums { 162 1 => Ok(()), 163 _ => Err(Error::MessageNumFdError), 164 }, 165 CRAS_CLIENT_STREAM_CONNECTED => match fd_nums { 166 // CRAS should return two shared memory areas the first which has 167 // mem::size_of::<cras_audio_shm_header>() bytes, and the second which has 168 // `samples_shm_size` bytes. 169 2 => Ok(()), 170 _ => Err(Error::MessageNumFdError), 171 }, 172 CRAS_CLIENT_AUDIO_DEBUG_INFO_READY => match fd_nums { 173 0 => Ok(()), 174 _ => Err(Error::MessageNumFdError), 175 }, 176 _ => Err(Error::MessageTypeError), 177 } 178 } 179 180 // Gets the message id get_id(&self) -> Result<CRAS_CLIENT_MESSAGE_ID>181 fn get_id(&self) -> Result<CRAS_CLIENT_MESSAGE_ID> { 182 let offset = mem::size_of::<u32>(); 183 match u32::from_le_bytes(self.data[offset..offset + 4].try_into()?) { 184 id if id == (CRAS_CLIENT_CONNECTED as u32) => Ok(CRAS_CLIENT_CONNECTED), 185 id if id == (CRAS_CLIENT_STREAM_CONNECTED as u32) => Ok(CRAS_CLIENT_STREAM_CONNECTED), 186 id if id == (CRAS_CLIENT_AUDIO_DEBUG_INFO_READY as u32) => { 187 Ok(CRAS_CLIENT_AUDIO_DEBUG_INFO_READY) 188 } 189 _ => Err(Error::MessageIdError), 190 } 191 } 192 193 // Gets a reference to the message content get_message<T: DataInit>(&self) -> Result<&T>194 fn get_message<T: DataInit>(&self) -> Result<&T> { 195 if self.len != mem::size_of::<T>() { 196 return Err(Error::InvalidSize); 197 } 198 T::from_slice(&self.data[..mem::size_of::<T>()]).ok_or(Error::MessageFromSliceError) 199 } 200 } 201