#include "uds/service_endpoint.h"

#include <poll.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <algorithm>  // std::min

#include <android-base/logging.h>
#include <android-base/strings.h>
#include <cutils/sockets.h>
#include <pdx/service.h>
#include <selinux/selinux.h>
#include <uds/channel_manager.h>
#include <uds/client_channel_factory.h>
#include <uds/ipc_helper.h>

namespace {

// Backlog passed to listen() on the endpoint socket.
constexpr int kMaxBackLogForSocketListen = 1;

using android::pdx::BorrowedChannelHandle;
using android::pdx::BorrowedHandle;
using android::pdx::ChannelReference;
using android::pdx::ErrorStatus;
using android::pdx::FileReference;
using android::pdx::LocalChannelHandle;
using android::pdx::LocalHandle;
using android::pdx::Status;
using android::pdx::uds::ChannelInfo;
using android::pdx::uds::ChannelManager;

// Per-message state allocated by Endpoint::AllocateMessageState(). Holds the
// received request header and payload, and accumulates the response header,
// payload and handles until Endpoint::MessageReply() sends them.
struct MessageState {
  // Resolves a file reference from the request into |handle|. A negative
  // |index| is stored in the handle as-is; a valid index moves the received
  // descriptor out of the request. Returns false if |index| is out of range.
  bool GetLocalFileHandle(int index, LocalHandle* handle) {
    if (index < 0) {
      handle->Reset(index);
    } else if (static_cast<size_t>(index) < request.file_descriptors.size()) {
      *handle = std::move(request.file_descriptors[index]);
    } else {
      return false;
    }
    return true;
  }

  // Resolves a channel reference from the request into |handle|. A negative
  // |index| yields a handle carrying that value; a valid index creates a
  // handle through ChannelManager from the received data/event fds. Returns
  // false if |index| is out of range.
  bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
    if (index < 0) {
      *handle = LocalChannelHandle{nullptr, index};
    } else if (static_cast<size_t>(index) < request.channels.size()) {
      auto& channel_info = request.channels[index];
      *handle = ChannelManager::Get().CreateHandle(
          std::move(channel_info.data_fd), std::move(channel_info.event_fd));
    } else {
      return false;
    }
    return true;
  }

  // Appends |handle| to the response and returns its index. An invalid handle
  // is not appended; its raw value is returned instead.
  Status<FileReference> PushFileHandle(BorrowedHandle handle) {
    if (!handle)
      return handle.Get();
    response.file_descriptors.push_back(std::move(handle));
    return response.file_descriptors.size() - 1;
  }

  // Appends the channel identified by |handle| to the response and returns
  // its index. An invalid handle is passed through as its raw value. Fails
  // with EINVAL if ChannelManager has no data for the handle.
  Status<ChannelReference> PushChannelHandle(BorrowedChannelHandle handle) {
    if (!handle)
      return handle.value();

    if (auto* channel_data =
            ChannelManager::Get().GetChannelData(handle.value())) {
      ChannelInfo<BorrowedHandle> channel_info;
      channel_info.data_fd.Reset(handle.value());
      channel_info.event_fd = channel_data->event_receiver.event_fd();
      response.channels.push_back(std::move(channel_info));
      return response.channels.size() - 1;
    } else {
      return ErrorStatus{EINVAL};
    }
  }

  // Appends a channel described by an explicit data/event fd pair to the
  // response and returns its index. Fails with EINVAL if either fd is invalid.
  Status<ChannelReference> PushChannelHandle(BorrowedHandle data_fd,
                                             BorrowedHandle event_fd) {
    if (!data_fd || !event_fd)
      return ErrorStatus{EINVAL};
    ChannelInfo<BorrowedHandle> channel_info;
    channel_info.data_fd = std::move(data_fd);
    channel_info.event_fd = std::move(event_fd);
    response.channels.push_back(std::move(channel_info));
    return response.channels.size() - 1;
  }

  // Appends the contents of every iovec entry to the response payload and
  // returns the total number of bytes appended.
  Status<size_t> WriteData(const iovec* vector, size_t vector_length) {
    size_t size = 0;
    for (size_t i = 0; i < vector_length; i++) {
      const auto* data = reinterpret_cast<const uint8_t*>(vector[i].iov_base);
      response_data.insert(response_data.end(), data,
                           data + vector[i].iov_len);
      size += vector[i].iov_len;
    }
    return size;
  }

  // Copies unread request payload into the iovec entries in order, advancing
  // the read position. Returns the number of bytes copied, which may be less
  // than the buffers' capacity when the payload runs out.
  Status<size_t> ReadData(const iovec* vector, size_t vector_length) {
    size_t size_remaining = request_data.size() - request_data_read_pos;
    size_t size = 0;
    for (size_t i = 0; i < vector_length && size_remaining > 0; i++) {
      size_t size_to_copy = std::min(size_remaining, vector[i].iov_len);
      memcpy(vector[i].iov_base, request_data.data() + request_data_read_pos,
             size_to_copy);
      size += size_to_copy;
      request_data_read_pos += size_to_copy;
      size_remaining -= size_to_copy;
    }
    return size;
  }

  android::pdx::uds::RequestHeader<LocalHandle> request;
  android::pdx::uds::ResponseHeader<BorrowedHandle> response;
  // Remote ends of pushed channels; kept here so they stay open for the
  // lifetime of this message state.
  std::vector<LocalHandle> sockets_to_close;
  std::vector<uint8_t> request_data;
  size_t request_data_read_pos{0};
  std::vector<uint8_t> response_data;
};

}  // anonymous namespace

namespace android {
namespace pdx {
namespace uds {

// Creates an endpoint for |endpoint_path|. When |use_init_socket_fd| is true
// the socket is obtained from android_get_control_socket(); otherwise a new
// AF_UNIX stream socket is created and bound to the path. Any failure aborts
// via CHECK.
Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
                   bool use_init_socket_fd)
    : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)},
      is_blocking_{blocking} {
  LocalHandle fd;
  if (use_init_socket_fd) {
    // Cut off the /dev/socket/ prefix from the full socket path and use the
    // resulting "name" to retrieve the file descriptor for the socket created
    // by the init process.
    constexpr char prefix[] = "/dev/socket/";
    CHECK(android::base::StartsWith(endpoint_path_, prefix))
        << "Endpoint::Endpoint: Socket name '" << endpoint_path_
        << "' must begin with '" << prefix << "'";
    std::string socket_name = endpoint_path_.substr(sizeof(prefix) - 1);
    fd.Reset(android_get_control_socket(socket_name.c_str()));
    CHECK(fd.IsValid())
        << "Endpoint::Endpoint: Unable to obtain the control socket fd for '"
        << socket_name << "'";
    fcntl(fd.Get(), F_SETFD, FD_CLOEXEC);
  } else {
    fd.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0));
    CHECK(fd.IsValid()) << "Endpoint::Endpoint: Failed to create socket: "
                        << strerror(errno);

    // Value-initialize so no indeterminate bytes are passed to bind().
    sockaddr_un local{};
    local.sun_family = AF_UNIX;
    strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path));
    local.sun_path[sizeof(local.sun_path) - 1] = '\0';

    // Remove any stale socket file left at the path before binding.
    unlink(local.sun_path);
    int ret =
        bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
    CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
  }
  Init(std::move(fd));
}

// Creates an endpoint around an existing socket.
// NOTE(review): this constructor does not set is_blocking_; presumably the
// header provides a default member initializer -- confirm.
Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }

// Starts listening on |socket_fd| (if valid), creates the cancel eventfd and
// the epoll set, registers both fds with epoll, and takes ownership of
// |socket_fd|. Any failure aborts via CHECK.
void Endpoint::Init(LocalHandle socket_fd) {
  if (socket_fd) {
    CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
        << "Endpoint::Endpoint: listen error: " << strerror(errno);
  }

  cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
  CHECK(cancel_event_fd_.IsValid())
      << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);

  epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC));
  CHECK(epoll_fd_.IsValid())
      << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);

  if (socket_fd) {
    epoll_event socket_event{};
    socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
    socket_event.data.fd = socket_fd.Get();
    int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
                        &socket_event);
    CHECK_EQ(ret, 0)
        << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
        << strerror(errno);
  }

  // The cancel fd is not one-shot: once signalled it keeps waking waiters.
  epoll_event cancel_event{};
  cancel_event.events = EPOLLIN;
  cancel_event.data.fd = cancel_event_fd_.Get();

  int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
                      &cancel_event);
  CHECK_EQ(ret, 0)
      << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
      << strerror(errno);
  socket_fd_ = std::move(socket_fd);
}

// Allocates the opaque per-message state; released by FreeMessageState().
void* Endpoint::AllocateMessageState() { return new MessageState; }

// Destroys state previously returned by AllocateMessageState().
void Endpoint::FreeMessageState(void* state) {
  delete static_cast<MessageState*>(state);
}

// Accepts a pending connection on the endpoint socket, creates a channel
// socket pair, registers the local end as a new channel, sends the remote end
// to the client and then receives the first message on the new channel into
// |message|. Returns EBADF if the endpoint has no listening socket.
Status<void> Endpoint::AcceptConnection(Message* message) {
  if (!socket_fd_)
    return ErrorStatus(EBADF);

  sockaddr_un remote;
  socklen_t addrlen = sizeof(remote);
  LocalHandle connection_fd{accept4(socket_fd_.Get(),
                                    reinterpret_cast<sockaddr*>(&remote),
                                    &addrlen, SOCK_CLOEXEC)};
  if (!connection_fd) {
    // Capture errno before logging, which may overwrite it.
    const int error = errno;
    ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
          strerror(error));
    return ErrorStatus(error);
  }

  LocalHandle local_socket;
  LocalHandle remote_socket;
  auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
  if (!status)
    return status;

  // Borrow the local channel handle before we move it into OnNewChannel().
  BorrowedHandle channel_handle = local_socket.Borrow();
  status = OnNewChannel(std::move(local_socket));
  if (!status)
    return status;

  // Send the channel socket fd to the client.
  ChannelConnectionInfo<LocalHandle> connection_info;
  connection_info.channel_fd = std::move(remote_socket);
  status = SendData(connection_fd.Borrow(), connection_info);

  if (status) {
    // Get the CHANNEL_OPEN message from client over the channel socket.
    status = ReceiveMessageForChannel(channel_handle, message);
  } else {
    CloseChannel(GetChannelId(channel_handle));
  }

  // Don't need the connection socket anymore. Further communication should
  // happen over the channel socket.
  shutdown(connection_fd.Get(), SHUT_WR);
  return status;
}

// Records the service that received messages are attributed to.
Status<void> Endpoint::SetService(Service* service) {
  service_ = service;
  return {};
}

// Associates |channel| with |channel_id|. Returns EINVAL for an unknown id.
Status<void> Endpoint::SetChannel(int channel_id, Channel* channel) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);
  auto channel_data = channels_.find(channel_id);
  if (channel_data == channels_.end())
    return ErrorStatus{EINVAL};
  channel_data->second.channel_state = channel;
  return {};
}

// Registers |channel_fd| as a new channel with no channel state attached.
Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);
  Status<void> status;
  status.PropagateError(OnNewChannelLocked(std::move(channel_fd), nullptr));
  return status;
}

// Adds |channel_fd| to the epoll set and stores it under a newly assigned
// channel id. Returns the id and a pointer to the stored channel data.
// Callers in this file hold channel_mutex_ while calling this.
Status<std::pair<int32_t, Endpoint::ChannelData*>> Endpoint::OnNewChannelLocked(
    LocalHandle channel_fd, Channel* channel_state) {
  epoll_event event{};
  event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
  event.data.fd = channel_fd.Get();
  if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) {
    // Capture errno before logging, which may overwrite it.
    const int error = errno;
    ALOGE(
        "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n",
        strerror(error));
    return ErrorStatus(error);
  }
  ChannelData channel_data;
  channel_data.event_set.AddDataFd(channel_fd);
  channel_data.data_fd = std::move(channel_fd);
  channel_data.channel_state = channel_state;
  for (;;) {
    // Try new channel IDs until we find one which is not already in the map.
    // Wrap to 1 explicitly instead of incrementing past the maximum, which
    // would be signed integer overflow.
    if (last_channel_id_ == std::numeric_limits<int32_t>::max())
      last_channel_id_ = 1;
    else
      ++last_channel_id_;
    auto iter = channels_.lower_bound(last_channel_id_);
    if (iter == channels_.end() || iter->first != last_channel_id_) {
      channel_fd_to_id_.emplace(channel_data.data_fd.Get(), last_channel_id_);
      iter = channels_.emplace_hint(iter, last_channel_id_,
                                    std::move(channel_data));
      return std::make_pair(last_channel_id_, &iter->second);
    }
  }
}

// Re-arms the one-shot epoll registration for |fd|.
Status<void> Endpoint::ReenableEpollEvent(const BorrowedHandle& fd) {
  epoll_event event{};
  event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
  event.data.fd = fd.Get();
  if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd.Get(), &event) < 0) {
    // Capture errno before logging, which may overwrite it.
    const int error = errno;
    ALOGE(
        "Endpoint::ReenableEpollEvent: Failed to re-enable channel to "
        "endpoint: %s\n",
        strerror(error));
    return ErrorStatus(error);
  }
  return {};
}

// Locks the channel table and closes |channel_id|.
Status<void> Endpoint::CloseChannel(int channel_id) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);
  return CloseChannelLocked(channel_id);
}

// Removes |channel_id| from the epoll set and from both lookup tables.
// Returns EINVAL for an unknown id. If the epoll removal fails the error is
// reported, but the channel is still erased from the tables.
Status<void> Endpoint::CloseChannelLocked(int32_t channel_id) {
  ALOGD_IF(TRACE, "Endpoint::CloseChannelLocked: channel_id=%d", channel_id);

  auto iter = channels_.find(channel_id);
  if (iter == channels_.end())
    return ErrorStatus{EINVAL};

  int channel_fd = iter->second.data_fd.Get();
  Status<void> status;
  epoll_event dummy;  // See BUGS in man 2 epoll_ctl.
  if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_fd, &dummy) < 0) {
    const int error = errno;
    status.SetError(error);
    ALOGE(
        "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: "
        "%s\n",
        strerror(error));
  } else {
    status.SetValue();
  }

  channel_fd_to_id_.erase(channel_fd);
  channels_.erase(iter);
  return status;
}

// Applies |clear_mask| / |set_mask| to the event set of |channel_id|.
// Returns EINVAL for an unknown id.
Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask,
                                           int set_mask) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);

  auto search = channels_.find(channel_id);
  if (search != channels_.end()) {
    auto& channel_data = search->second;
    channel_data.event_set.ModifyEvents(clear_mask, set_mask);
    return {};
  }

  return ErrorStatus{EINVAL};
}

// Creates a connected AF_UNIX stream socket pair for a channel and enables
// SO_PASSCRED on the local end. On socketpair() failure the outputs are left
// untouched. If only the setsockopt() call fails, the error is returned but
// both outputs have already been set.
Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket,
                                               LocalHandle* remote_socket) {
  Status<void> status;
  char* endpoint_context = nullptr;
  // Make sure the channel socket has the correct SELinux label applied.
  // Here we get the label from the endpoint file descriptor, which should be
  // something like "u:object_r:pdx_service_endpoint_socket:s0" and replace
  // "endpoint" with "channel" to produce the channel label such as this:
  // "u:object_r:pdx_service_channel_socket:s0".
  if (fgetfilecon_raw(socket_fd_.Get(), &endpoint_context) > 0) {
    std::string channel_context = endpoint_context;
    freecon(endpoint_context);
    const std::string suffix = "_endpoint_socket";
    auto pos = channel_context.find(suffix);
    if (pos != std::string::npos) {
      channel_context.replace(pos, suffix.size(), "_channel_socket");
    } else {
      ALOGW(
          "Endpoint::CreateChannelSocketPair: Endpoint security context '%s' "
          "does not contain expected substring '%s'",
          channel_context.c_str(), suffix.c_str());
    }
    ALOGE_IF(setsockcreatecon_raw(channel_context.c_str()) == -1,
             "Endpoint::CreateChannelSocketPair: Failed to set channel socket "
             "security context: %s",
             strerror(errno));
  } else {
    ALOGE(
        "Endpoint::CreateChannelSocketPair: Failed to obtain the endpoint "
        "socket's security context: %s",
        strerror(errno));
  }

  int channel_pair[2] = {};
  const int pair_result =
      socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair);
  const int pair_error = errno;

  // Always restore the default socket creation context, including when
  // socketpair() failed, so the channel label does not leak onto sockets
  // created later.
  setsockcreatecon_raw(nullptr);

  if (pair_result == -1) {
    ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s",
          strerror(pair_error));
    status.SetError(pair_error);
    return status;
  }

  local_socket->Reset(channel_pair[0]);
  remote_socket->Reset(channel_pair[1]);

  int optval = 1;
  if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval,
                 sizeof(optval)) == -1) {
    // Capture errno before logging, which may overwrite it.
    const int error = errno;
    ALOGE(
        "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of "
        "the credentials for channel %d: %s",
        local_socket->Get(), strerror(error));
    status.SetError(error);
  }
  return status;
}

// Creates a new channel bound to |channel|, stores its id in |channel_id| and
// queues the remote end on |message|'s response. Returns a handle referring
// to the queued channel entry.
Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
                                                  int /*flags*/,
                                                  Channel* channel,
                                                  int* channel_id) {
  LocalHandle local_socket;
  LocalHandle remote_socket;
  auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
  if (!status)
    return status.error_status();

  std::lock_guard<std::mutex> autolock(channel_mutex_);
  auto channel_data = OnNewChannelLocked(std::move(local_socket), channel);
  if (!channel_data)
    return channel_data.error_status();
  *channel_id = channel_data.get().first;

  // Flags are ignored for now.
  // TODO(xiaohuit): Implement those.

  auto* state = static_cast<MessageState*>(message->GetState());
  Status<ChannelReference> ref = state->PushChannelHandle(
      remote_socket.Borrow(),
      channel_data.get().second->event_set.event_fd().Borrow());
  if (!ref)
    return ref.error_status();
  // The response only borrows the remote socket, so park the owning handle in
  // the message state.
  state->sockets_to_close.push_back(std::move(remote_socket));
  return RemoteChannelHandle{ref.get()};
}

// Not implemented; always fails with EFAULT.
Status<int> Endpoint::CheckChannel(const Message* /*message*/,
                                   ChannelReference /*ref*/,
                                   Channel** /*channel*/) {
  // TODO(xiaohuit): Implement this.
  return ErrorStatus(EFAULT);
}

// Returns the channel state stored for |channel_id|, or nullptr if unknown.
Channel* Endpoint::GetChannelState(int32_t channel_id) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);
  auto channel_data = channels_.find(channel_id);
  return (channel_data != channels_.end()) ? channel_data->second.channel_state
                                           : nullptr;
}

// Returns a borrowed handle to the data socket of |channel_id|, or an empty
// handle if the id is unknown.
BorrowedHandle Endpoint::GetChannelSocketFd(int32_t channel_id) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);
  BorrowedHandle handle;
  auto channel_data = channels_.find(channel_id);
  if (channel_data != channels_.end())
    handle = channel_data->second.data_fd.Borrow();
  return handle;
}

// Returns a borrowed handle to the event fd of |channel_id|, or an empty
// handle if the id is unknown.
BorrowedHandle Endpoint::GetChannelEventFd(int32_t channel_id) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);
  BorrowedHandle handle;
  auto channel_data = channels_.find(channel_id);
  if (channel_data != channels_.end())
    handle = channel_data->second.event_set.event_fd().Borrow();
  return handle;
}

// Maps a channel data socket to its channel id, or -1 if it is not registered.
int32_t Endpoint::GetChannelId(const BorrowedHandle& channel_fd) {
  std::lock_guard<std::mutex> autolock(channel_mutex_);
  auto iter = channel_fd_to_id_.find(channel_fd.Get());
  return (iter != channel_fd_to_id_.end()) ? iter->second : -1;
}

// Reads one request (header plus optional payload) from |channel_fd| and
// builds |message| from it. If a read fails with ESHUTDOWN, a CHANNEL_CLOSE
// message is synthesized and success is returned; any other failure closes
// the channel and returns the error.
Status<void> Endpoint::ReceiveMessageForChannel(
    const BorrowedHandle& channel_fd, Message* message) {
  RequestHeader<LocalHandle> request;
  int32_t channel_id = GetChannelId(channel_fd);
  auto status = ReceiveData(channel_fd.Borrow(), &request);
  if (!status) {
    if (status.error() == ESHUTDOWN) {
      BuildCloseMessage(channel_id, message);
      return {};
    } else {
      CloseChannel(channel_id);
      return status;
    }
  }

  // Keep copies of the scalar fields needed after |request| is moved into the
  // message state below, so the moved-from object is never read.
  const bool is_impulse = request.is_impulse;
  const auto send_len = request.send_len;

  MessageInfo info;
  info.pid = request.cred.pid;
  info.tid = -1;
  info.cid = channel_id;
  info.mid = is_impulse ? Message::IMPULSE_MESSAGE_ID
                        : GetNextAvailableMessageId();
  info.euid = request.cred.uid;
  info.egid = request.cred.gid;
  info.op = request.op;
  info.flags = 0;
  info.service = service_;
  info.channel = GetChannelState(channel_id);
  info.send_len = send_len;
  info.recv_len = request.max_recv_len;
  info.fd_count = request.file_descriptors.size();
  static_assert(sizeof(info.impulse) == request.impulse_payload.size(),
                "Impulse payload sizes must be the same in RequestHeader and "
                "MessageInfo");
  memcpy(info.impulse, request.impulse_payload.data(),
         request.impulse_payload.size());
  *message = Message{info};
  auto* state = static_cast<MessageState*>(message->GetState());
  state->request = std::move(request);
  if (send_len > 0 && !is_impulse) {
    state->request_data.resize(send_len);
    status = ReceiveData(channel_fd, state->request_data.data(),
                         state->request_data.size());
  }

  // Impulses get no reply, so the one-shot epoll registration is re-armed
  // here rather than in MessageReply().
  if (status && is_impulse)
    status = ReenableEpollEvent(channel_fd);

  if (!status) {
    if (status.error() == ESHUTDOWN) {
      BuildCloseMessage(channel_id, message);
      return {};
    } else {
      CloseChannel(channel_id);
      return status;
    }
  }

  return status;
}

// Fills |message| with a synthesized CHANNEL_CLOSE message for |channel_id|.
void Endpoint::BuildCloseMessage(int32_t channel_id, Message* message) {
  ALOGD_IF(TRACE, "Endpoint::BuildCloseMessage: channel_id=%d", channel_id);
  MessageInfo info;
  info.pid = -1;
  info.tid = -1;
  info.cid = channel_id;
  info.mid = GetNextAvailableMessageId();
  info.euid = -1;
  info.egid = -1;
  info.op = opcodes::CHANNEL_CLOSE;
  info.flags = 0;
  info.service = service_;
  info.channel = GetChannelState(channel_id);
  info.send_len = 0;
  info.recv_len = 0;
  info.fd_count = 0;
  *message = Message{info};
}

// Waits for one epoll event and turns it into |message|. Returns ETIMEDOUT
// when no event is ready (non-blocking mode) and ESHUTDOWN when the cancel
// eventfd fired.
Status<void> Endpoint::MessageReceive(Message* message) {
  // Receive at most one event from the epoll set. This should prevent multiple
  // dispatch threads from attempting to handle messages on the same socket at
  // the same time.
  epoll_event event;
  int count = RETRY_EINTR(
      epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0));
  if (count < 0) {
    // Capture errno before logging, which may overwrite it.
    const int error = errno;
    ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n",
          strerror(error));
    return ErrorStatus{error};
  } else if (count == 0) {
    return ErrorStatus{ETIMEDOUT};
  }

  if (event.data.fd == cancel_event_fd_.Get()) {
    return ErrorStatus{ESHUTDOWN};
  }

  if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
    auto status = AcceptConnection(message);
    if (!status)
      return status;
    return ReenableEpollEvent(socket_fd_.Borrow());
  }

  BorrowedHandle channel_fd{event.data.fd};
  if (event.events & (EPOLLRDHUP | EPOLLHUP)) {
    BuildCloseMessage(GetChannelId(channel_fd), message);
    return {};
  }

  return ReceiveMessageForChannel(channel_fd, message);
}

// Sends the accumulated response for |message| with |return_code| and re-arms
// the channel socket in epoll. CHANNEL_CLOSE, and CHANNEL_OPEN with a negative
// |return_code|, close the channel instead of replying. A successful
// CHANNEL_OPEN replies with the channel's event fd. Returns EBADF if the
// message's channel is unknown.
Status<void> Endpoint::MessageReply(Message* message, int return_code) {
  const int32_t channel_id = message->GetChannelId();
  auto channel_socket = GetChannelSocketFd(channel_id);
  if (!channel_socket)
    return ErrorStatus{EBADF};

  auto* state = static_cast<MessageState*>(message->GetState());
  switch (message->GetOp()) {
    case opcodes::CHANNEL_CLOSE:
      return CloseChannel(channel_id);

    case opcodes::CHANNEL_OPEN:
      if (return_code < 0) {
        return CloseChannel(channel_id);
      } else {
        // Reply with the event fd.
        auto push_status = state->PushFileHandle(GetChannelEventFd(channel_id));
        state->response_data.clear();  // Just in case...
        if (!push_status)
          return push_status.error_status();
        return_code = push_status.get();
      }
      break;
  }

  state->response.ret_code = return_code;
  state->response.recv_len = state->response_data.size();
  auto status = SendData(channel_socket, state->response);
  if (status && !state->response_data.empty()) {
    status = SendData(channel_socket, state->response_data.data(),
                      state->response_data.size());
  }

  if (status)
    status = ReenableEpollEvent(channel_socket);

  return status;
}

// Replies with a reference to |push_fd| as the return code.
Status<void> Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) {
  auto* state = static_cast<MessageState*>(message->GetState());
  auto ref = state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)});
  if (!ref)
    return ref.error_status();
  return MessageReply(message, ref.get());
}

// Replies with a reference to a locally owned channel handle.
Status<void> Endpoint::MessageReplyChannelHandle(
    Message* message, const LocalChannelHandle& handle) {
  auto* state = static_cast<MessageState*>(message->GetState());
  auto ref = state->PushChannelHandle(handle.Borrow());
  if (!ref)
    return ref.error_status();
  return MessageReply(message, ref.get());
}

// Replies with a reference to a borrowed channel handle.
Status<void> Endpoint::MessageReplyChannelHandle(
    Message* message, const BorrowedChannelHandle& handle) {
  auto* state = static_cast<MessageState*>(message->GetState());
  auto ref = state->PushChannelHandle(handle.Duplicate());
  if (!ref)
    return ref.error_status();
  return MessageReply(message, ref.get());
}

// Replies with the value of a remote channel handle as the return code.
Status<void> Endpoint::MessageReplyChannelHandle(
    Message* message, const RemoteChannelHandle& handle) {
  return MessageReply(message, handle.value());
}

// Copies request payload of |message| into |vector|; see
// MessageState::ReadData().
Status<size_t> Endpoint::ReadMessageData(Message* message, const iovec* vector,
                                         size_t vector_length) {
  auto* state = static_cast<MessageState*>(message->GetState());
  return state->ReadData(vector, vector_length);
}

// Appends |vector| to the response payload of |message|; see
// MessageState::WriteData().
Status<size_t> Endpoint::WriteMessageData(Message* message, const iovec* vector,
                                          size_t vector_length) {
  auto* state = static_cast<MessageState*>(message->GetState());
  return state->WriteData(vector, vector_length);
}

// Queues a locally owned file handle on the response of |message|.
Status<FileReference> Endpoint::PushFileHandle(Message* message,
                                               const LocalHandle& handle) {
  auto* state = static_cast<MessageState*>(message->GetState());
  return state->PushFileHandle(handle.Borrow());
}

// Queues a borrowed file handle on the response of |message|.
Status<FileReference> Endpoint::PushFileHandle(Message* message,
                                               const BorrowedHandle& handle) {
  auto* state = static_cast<MessageState*>(message->GetState());
  return state->PushFileHandle(handle.Duplicate());
}

// Remote handles are not queued; their value is returned unchanged.
Status<FileReference> Endpoint::PushFileHandle(Message* /*message*/,
                                               const RemoteHandle& handle) {
  return handle.Get();
}

// Queues a locally owned channel handle on the response of |message|.
Status<ChannelReference> Endpoint::PushChannelHandle(
    Message* message, const LocalChannelHandle& handle) {
  auto* state = static_cast<MessageState*>(message->GetState());
  return state->PushChannelHandle(handle.Borrow());
}

// Queues a borrowed channel handle on the response of |message|.
Status<ChannelReference> Endpoint::PushChannelHandle(
    Message* message, const BorrowedChannelHandle& handle) {
  auto* state = static_cast<MessageState*>(message->GetState());
  return state->PushChannelHandle(handle.Duplicate());
}

// Remote channel handles are not queued; their value is returned unchanged.
Status<ChannelReference> Endpoint::PushChannelHandle(
    Message* /*message*/, const RemoteChannelHandle& handle) {
  return handle.value();
}

// Resolves |ref| against the request of |message|. The handle is left
// default-constructed when |ref| is out of range.
LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const {
  LocalHandle handle;
  auto* state = static_cast<MessageState*>(message->GetState());
  state->GetLocalFileHandle(ref, &handle);
  return handle;
}

// Resolves |ref| against the request of |message|. The handle is left
// default-constructed when |ref| is out of range.
LocalChannelHandle Endpoint::GetChannelHandle(Message* message,
                                              ChannelReference ref) const {
  LocalChannelHandle handle;
  auto* state = static_cast<MessageState*>(message->GetState());
  state->GetLocalChannelHandle(ref, &handle);
  return handle;
}

// Signals the cancel eventfd, which makes MessageReceive() return ESHUTDOWN.
Status<void> Endpoint::Cancel() {
  if (eventfd_write(cancel_event_fd_.Get(), 1) < 0)
    return ErrorStatus{errno};
  return {};
}

// Creates an endpoint for |endpoint_path|, leaving the third constructor
// argument at its declared default.
std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path,
                                           mode_t /*unused_mode*/,
                                           bool blocking) {
  return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking));
}

// Creates an endpoint that creates and binds its own socket at
// |endpoint_path| instead of using the init-provided control socket.
std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
    const std::string& endpoint_path, bool blocking) {
  return std::unique_ptr<Endpoint>(
      new Endpoint(endpoint_path, blocking, false));
}

// Creates an endpoint around an already opened socket.
std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
  return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
}

// Enables SO_PASSCRED on |channel_fd| and registers it as a new channel.
Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
  int optval = 1;
  if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
                 sizeof(optval)) == -1) {
    // Capture errno before logging, which may overwrite it.
    const int error = errno;
    ALOGE(
        "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving "
        "of the credentials for channel %d: %s",
        channel_fd.Get(), strerror(error));
    return ErrorStatus(error);
  }
  return OnNewChannel(std::move(channel_fd));
}

}  // namespace uds
}  // namespace pdx
}  // namespace android