1 #include "uds/service_endpoint.h"
2
3 #include <poll.h>
4 #include <sys/epoll.h>
5 #include <sys/eventfd.h>
6 #include <sys/socket.h>
7 #include <sys/un.h>
8 #include <algorithm> // std::min
9
10 #include <android-base/logging.h>
11 #include <android-base/strings.h>
12 #include <cutils/sockets.h>
13 #include <pdx/service.h>
14 #include <selinux/selinux.h>
15 #include <uds/channel_manager.h>
16 #include <uds/client_channel_factory.h>
17 #include <uds/ipc_helper.h>
18
19 namespace {
20
// Backlog passed to listen(2) for the endpoint's listening socket, i.e. the
// number of pending connections the kernel may queue before accept().
constexpr int kMaxBackLogForSocketListen{1};
22
23 using android::pdx::BorrowedChannelHandle;
24 using android::pdx::BorrowedHandle;
25 using android::pdx::ChannelReference;
26 using android::pdx::ErrorStatus;
27 using android::pdx::FileReference;
28 using android::pdx::LocalChannelHandle;
29 using android::pdx::LocalHandle;
30 using android::pdx::Status;
31 using android::pdx::uds::ChannelInfo;
32 using android::pdx::uds::ChannelManager;
33
34 struct MessageState {
GetLocalFileHandle__anon9a8a71d40111::MessageState35 bool GetLocalFileHandle(int index, LocalHandle* handle) {
36 if (index < 0) {
37 handle->Reset(index);
38 } else if (static_cast<size_t>(index) < request.file_descriptors.size()) {
39 *handle = std::move(request.file_descriptors[index]);
40 } else {
41 return false;
42 }
43 return true;
44 }
45
GetLocalChannelHandle__anon9a8a71d40111::MessageState46 bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
47 if (index < 0) {
48 *handle = LocalChannelHandle{nullptr, index};
49 } else if (static_cast<size_t>(index) < request.channels.size()) {
50 auto& channel_info = request.channels[index];
51 *handle = ChannelManager::Get().CreateHandle(
52 std::move(channel_info.data_fd), std::move(channel_info.event_fd));
53 } else {
54 return false;
55 }
56 return true;
57 }
58
PushFileHandle__anon9a8a71d40111::MessageState59 Status<FileReference> PushFileHandle(BorrowedHandle handle) {
60 if (!handle)
61 return handle.Get();
62 response.file_descriptors.push_back(std::move(handle));
63 return response.file_descriptors.size() - 1;
64 }
65
PushChannelHandle__anon9a8a71d40111::MessageState66 Status<ChannelReference> PushChannelHandle(BorrowedChannelHandle handle) {
67 if (!handle)
68 return handle.value();
69
70 if (auto* channel_data =
71 ChannelManager::Get().GetChannelData(handle.value())) {
72 ChannelInfo<BorrowedHandle> channel_info;
73 channel_info.data_fd.Reset(handle.value());
74 channel_info.event_fd = channel_data->event_receiver.event_fd();
75 response.channels.push_back(std::move(channel_info));
76 return response.channels.size() - 1;
77 } else {
78 return ErrorStatus{EINVAL};
79 }
80 }
81
PushChannelHandle__anon9a8a71d40111::MessageState82 Status<ChannelReference> PushChannelHandle(BorrowedHandle data_fd,
83 BorrowedHandle event_fd) {
84 if (!data_fd || !event_fd)
85 return ErrorStatus{EINVAL};
86 ChannelInfo<BorrowedHandle> channel_info;
87 channel_info.data_fd = std::move(data_fd);
88 channel_info.event_fd = std::move(event_fd);
89 response.channels.push_back(std::move(channel_info));
90 return response.channels.size() - 1;
91 }
92
WriteData__anon9a8a71d40111::MessageState93 Status<size_t> WriteData(const iovec* vector, size_t vector_length) {
94 size_t size = 0;
95 for (size_t i = 0; i < vector_length; i++) {
96 const auto* data = reinterpret_cast<const uint8_t*>(vector[i].iov_base);
97 response_data.insert(response_data.end(), data, data + vector[i].iov_len);
98 size += vector[i].iov_len;
99 }
100 return size;
101 }
102
ReadData__anon9a8a71d40111::MessageState103 Status<size_t> ReadData(const iovec* vector, size_t vector_length) {
104 size_t size_remaining = request_data.size() - request_data_read_pos;
105 size_t size = 0;
106 for (size_t i = 0; i < vector_length && size_remaining > 0; i++) {
107 size_t size_to_copy = std::min(size_remaining, vector[i].iov_len);
108 memcpy(vector[i].iov_base, request_data.data() + request_data_read_pos,
109 size_to_copy);
110 size += size_to_copy;
111 request_data_read_pos += size_to_copy;
112 size_remaining -= size_to_copy;
113 }
114 return size;
115 }
116
117 android::pdx::uds::RequestHeader<LocalHandle> request;
118 android::pdx::uds::ResponseHeader<BorrowedHandle> response;
119 std::vector<LocalHandle> sockets_to_close;
120 std::vector<uint8_t> request_data;
121 size_t request_data_read_pos{0};
122 std::vector<uint8_t> response_data;
123 };
124
125 } // anonymous namespace
126
127 namespace android {
128 namespace pdx {
129 namespace uds {
130
Endpoint(const std::string & endpoint_path,bool blocking,bool use_init_socket_fd)131 Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
132 bool use_init_socket_fd)
133 : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)},
134 is_blocking_{blocking} {
135 LocalHandle fd;
136 if (use_init_socket_fd) {
137 // Cut off the /dev/socket/ prefix from the full socket path and use the
138 // resulting "name" to retrieve the file descriptor for the socket created
139 // by the init process.
140 constexpr char prefix[] = "/dev/socket/";
141 CHECK(android::base::StartsWith(endpoint_path_, prefix))
142 << "Endpoint::Endpoint: Socket name '" << endpoint_path_
143 << "' must begin with '" << prefix << "'";
144 std::string socket_name = endpoint_path_.substr(sizeof(prefix) - 1);
145 fd.Reset(android_get_control_socket(socket_name.c_str()));
146 CHECK(fd.IsValid())
147 << "Endpoint::Endpoint: Unable to obtain the control socket fd for '"
148 << socket_name << "'";
149 fcntl(fd.Get(), F_SETFD, FD_CLOEXEC);
150 } else {
151 fd.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0));
152 CHECK(fd.IsValid()) << "Endpoint::Endpoint: Failed to create socket: "
153 << strerror(errno);
154
155 sockaddr_un local;
156 local.sun_family = AF_UNIX;
157 strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path));
158 local.sun_path[sizeof(local.sun_path) - 1] = '\0';
159
160 unlink(local.sun_path);
161 int ret =
162 bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
163 CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
164 }
165 Init(std::move(fd));
166 }
167
Endpoint(LocalHandle socket_fd)168 Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }
169
Init(LocalHandle socket_fd)170 void Endpoint::Init(LocalHandle socket_fd) {
171 if (socket_fd) {
172 CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
173 << "Endpoint::Endpoint: listen error: " << strerror(errno);
174 }
175 cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
176 CHECK(cancel_event_fd_.IsValid())
177 << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
178
179 epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC));
180 CHECK(epoll_fd_.IsValid())
181 << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);
182
183 if (socket_fd) {
184 epoll_event socket_event;
185 socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
186 socket_event.data.fd = socket_fd.Get();
187 int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
188 &socket_event);
189 CHECK_EQ(ret, 0)
190 << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
191 << strerror(errno);
192 }
193
194 epoll_event cancel_event;
195 cancel_event.events = EPOLLIN;
196 cancel_event.data.fd = cancel_event_fd_.Get();
197
198 int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
199 &cancel_event);
200 CHECK_EQ(ret, 0)
201 << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
202 << strerror(errno);
203 socket_fd_ = std::move(socket_fd);
204 }
205
AllocateMessageState()206 void* Endpoint::AllocateMessageState() { return new MessageState; }
207
FreeMessageState(void * state)208 void Endpoint::FreeMessageState(void* state) {
209 delete static_cast<MessageState*>(state);
210 }
211
AcceptConnection(Message * message)212 Status<void> Endpoint::AcceptConnection(Message* message) {
213 if (!socket_fd_)
214 return ErrorStatus(EBADF);
215
216 sockaddr_un remote;
217 socklen_t addrlen = sizeof(remote);
218 LocalHandle connection_fd{accept4(socket_fd_.Get(),
219 reinterpret_cast<sockaddr*>(&remote),
220 &addrlen, SOCK_CLOEXEC)};
221 if (!connection_fd) {
222 ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
223 strerror(errno));
224 return ErrorStatus(errno);
225 }
226
227 LocalHandle local_socket;
228 LocalHandle remote_socket;
229 auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
230 if (!status)
231 return status;
232
233 // Borrow the local channel handle before we move it into OnNewChannel().
234 BorrowedHandle channel_handle = local_socket.Borrow();
235 status = OnNewChannel(std::move(local_socket));
236 if (!status)
237 return status;
238
239 // Send the channel socket fd to the client.
240 ChannelConnectionInfo<LocalHandle> connection_info;
241 connection_info.channel_fd = std::move(remote_socket);
242 status = SendData(connection_fd.Borrow(), connection_info);
243
244 if (status) {
245 // Get the CHANNEL_OPEN message from client over the channel socket.
246 status = ReceiveMessageForChannel(channel_handle, message);
247 } else {
248 CloseChannel(GetChannelId(channel_handle));
249 }
250
251 // Don't need the connection socket anymore. Further communication should
252 // happen over the channel socket.
253 shutdown(connection_fd.Get(), SHUT_WR);
254 return status;
255 }
256
SetService(Service * service)257 Status<void> Endpoint::SetService(Service* service) {
258 service_ = service;
259 return {};
260 }
261
SetChannel(int channel_id,Channel * channel)262 Status<void> Endpoint::SetChannel(int channel_id, Channel* channel) {
263 std::lock_guard<std::mutex> autolock(channel_mutex_);
264 auto channel_data = channels_.find(channel_id);
265 if (channel_data == channels_.end())
266 return ErrorStatus{EINVAL};
267 channel_data->second.channel_state = channel;
268 return {};
269 }
270
OnNewChannel(LocalHandle channel_fd)271 Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) {
272 std::lock_guard<std::mutex> autolock(channel_mutex_);
273 Status<void> status;
274 status.PropagateError(OnNewChannelLocked(std::move(channel_fd), nullptr));
275 return status;
276 }
277
OnNewChannelLocked(LocalHandle channel_fd,Channel * channel_state)278 Status<std::pair<int32_t, Endpoint::ChannelData*>> Endpoint::OnNewChannelLocked(
279 LocalHandle channel_fd, Channel* channel_state) {
280 epoll_event event;
281 event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
282 event.data.fd = channel_fd.Get();
283 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) {
284 ALOGE(
285 "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n",
286 strerror(errno));
287 return ErrorStatus(errno);
288 }
289 ChannelData channel_data;
290 channel_data.event_set.AddDataFd(channel_fd);
291 channel_data.data_fd = std::move(channel_fd);
292 channel_data.channel_state = channel_state;
293 for (;;) {
294 // Try new channel IDs until we find one which is not already in the map.
295 if (last_channel_id_++ == std::numeric_limits<int32_t>::max())
296 last_channel_id_ = 1;
297 auto iter = channels_.lower_bound(last_channel_id_);
298 if (iter == channels_.end() || iter->first != last_channel_id_) {
299 channel_fd_to_id_.emplace(channel_data.data_fd.Get(), last_channel_id_);
300 iter = channels_.emplace_hint(iter, last_channel_id_,
301 std::move(channel_data));
302 return std::make_pair(last_channel_id_, &iter->second);
303 }
304 }
305 }
306
ReenableEpollEvent(const BorrowedHandle & fd)307 Status<void> Endpoint::ReenableEpollEvent(const BorrowedHandle& fd) {
308 epoll_event event;
309 event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
310 event.data.fd = fd.Get();
311 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd.Get(), &event) < 0) {
312 ALOGE(
313 "Endpoint::ReenableEpollEvent: Failed to re-enable channel to "
314 "endpoint: %s\n",
315 strerror(errno));
316 return ErrorStatus(errno);
317 }
318 return {};
319 }
320
CloseChannel(int channel_id)321 Status<void> Endpoint::CloseChannel(int channel_id) {
322 std::lock_guard<std::mutex> autolock(channel_mutex_);
323 return CloseChannelLocked(channel_id);
324 }
325
CloseChannelLocked(int32_t channel_id)326 Status<void> Endpoint::CloseChannelLocked(int32_t channel_id) {
327 ALOGD_IF(TRACE, "Endpoint::CloseChannelLocked: channel_id=%d", channel_id);
328
329 auto iter = channels_.find(channel_id);
330 if (iter == channels_.end())
331 return ErrorStatus{EINVAL};
332
333 int channel_fd = iter->second.data_fd.Get();
334 Status<void> status;
335 epoll_event dummy; // See BUGS in man 2 epoll_ctl.
336 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_fd, &dummy) < 0) {
337 status.SetError(errno);
338 ALOGE(
339 "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: "
340 "%s\n",
341 strerror(errno));
342 } else {
343 status.SetValue();
344 }
345
346 channel_fd_to_id_.erase(channel_fd);
347 channels_.erase(iter);
348 return status;
349 }
350
ModifyChannelEvents(int channel_id,int clear_mask,int set_mask)351 Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask,
352 int set_mask) {
353 std::lock_guard<std::mutex> autolock(channel_mutex_);
354
355 auto search = channels_.find(channel_id);
356 if (search != channels_.end()) {
357 auto& channel_data = search->second;
358 channel_data.event_set.ModifyEvents(clear_mask, set_mask);
359 return {};
360 }
361
362 return ErrorStatus{EINVAL};
363 }
364
CreateChannelSocketPair(LocalHandle * local_socket,LocalHandle * remote_socket)365 Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket,
366 LocalHandle* remote_socket) {
367 Status<void> status;
368 char* endpoint_context = nullptr;
369 // Make sure the channel socket has the correct SELinux label applied.
370 // Here we get the label from the endpoint file descriptor, which should be
371 // something like "u:object_r:pdx_service_endpoint_socket:s0" and replace
372 // "endpoint" with "channel" to produce the channel label such as this:
373 // "u:object_r:pdx_service_channel_socket:s0".
374 if (fgetfilecon_raw(socket_fd_.Get(), &endpoint_context) > 0) {
375 std::string channel_context = endpoint_context;
376 freecon(endpoint_context);
377 const std::string suffix = "_endpoint_socket";
378 auto pos = channel_context.find(suffix);
379 if (pos != std::string::npos) {
380 channel_context.replace(pos, suffix.size(), "_channel_socket");
381 } else {
382 ALOGW(
383 "Endpoint::CreateChannelSocketPair: Endpoint security context '%s' "
384 "does not contain expected substring '%s'",
385 channel_context.c_str(), suffix.c_str());
386 }
387 ALOGE_IF(setsockcreatecon_raw(channel_context.c_str()) == -1,
388 "Endpoint::CreateChannelSocketPair: Failed to set channel socket "
389 "security context: %s",
390 strerror(errno));
391 } else {
392 ALOGE(
393 "Endpoint::CreateChannelSocketPair: Failed to obtain the endpoint "
394 "socket's security context: %s",
395 strerror(errno));
396 }
397
398 int channel_pair[2] = {};
399 if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
400 ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s",
401 strerror(errno));
402 status.SetError(errno);
403 return status;
404 }
405
406 setsockcreatecon_raw(nullptr);
407
408 local_socket->Reset(channel_pair[0]);
409 remote_socket->Reset(channel_pair[1]);
410
411 int optval = 1;
412 if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval,
413 sizeof(optval)) == -1) {
414 ALOGE(
415 "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of "
416 "the credentials for channel %d: %s",
417 local_socket->Get(), strerror(errno));
418 status.SetError(errno);
419 }
420 return status;
421 }
422
PushChannel(Message * message,int,Channel * channel,int * channel_id)423 Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
424 int /*flags*/,
425 Channel* channel,
426 int* channel_id) {
427 LocalHandle local_socket;
428 LocalHandle remote_socket;
429 auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
430 if (!status)
431 return status.error_status();
432
433 std::lock_guard<std::mutex> autolock(channel_mutex_);
434 auto channel_data = OnNewChannelLocked(std::move(local_socket), channel);
435 if (!channel_data)
436 return channel_data.error_status();
437 *channel_id = channel_data.get().first;
438
439 // Flags are ignored for now.
440 // TODO(xiaohuit): Implement those.
441
442 auto* state = static_cast<MessageState*>(message->GetState());
443 Status<ChannelReference> ref = state->PushChannelHandle(
444 remote_socket.Borrow(),
445 channel_data.get().second->event_set.event_fd().Borrow());
446 if (!ref)
447 return ref.error_status();
448 state->sockets_to_close.push_back(std::move(remote_socket));
449 return RemoteChannelHandle{ref.get()};
450 }
451
CheckChannel(const Message *,ChannelReference,Channel **)452 Status<int> Endpoint::CheckChannel(const Message* /*message*/,
453 ChannelReference /*ref*/,
454 Channel** /*channel*/) {
455 // TODO(xiaohuit): Implement this.
456 return ErrorStatus(EFAULT);
457 }
458
GetChannelState(int32_t channel_id)459 Channel* Endpoint::GetChannelState(int32_t channel_id) {
460 std::lock_guard<std::mutex> autolock(channel_mutex_);
461 auto channel_data = channels_.find(channel_id);
462 return (channel_data != channels_.end()) ? channel_data->second.channel_state
463 : nullptr;
464 }
465
GetChannelSocketFd(int32_t channel_id)466 BorrowedHandle Endpoint::GetChannelSocketFd(int32_t channel_id) {
467 std::lock_guard<std::mutex> autolock(channel_mutex_);
468 BorrowedHandle handle;
469 auto channel_data = channels_.find(channel_id);
470 if (channel_data != channels_.end())
471 handle = channel_data->second.data_fd.Borrow();
472 return handle;
473 }
474
GetChannelEventFd(int32_t channel_id)475 BorrowedHandle Endpoint::GetChannelEventFd(int32_t channel_id) {
476 std::lock_guard<std::mutex> autolock(channel_mutex_);
477 BorrowedHandle handle;
478 auto channel_data = channels_.find(channel_id);
479 if (channel_data != channels_.end())
480 handle = channel_data->second.event_set.event_fd().Borrow();
481 return handle;
482 }
483
GetChannelId(const BorrowedHandle & channel_fd)484 int32_t Endpoint::GetChannelId(const BorrowedHandle& channel_fd) {
485 std::lock_guard<std::mutex> autolock(channel_mutex_);
486 auto iter = channel_fd_to_id_.find(channel_fd.Get());
487 return (iter != channel_fd_to_id_.end()) ? iter->second : -1;
488 }
489
ReceiveMessageForChannel(const BorrowedHandle & channel_fd,Message * message)490 Status<void> Endpoint::ReceiveMessageForChannel(
491 const BorrowedHandle& channel_fd, Message* message) {
492 RequestHeader<LocalHandle> request;
493 int32_t channel_id = GetChannelId(channel_fd);
494 auto status = ReceiveData(channel_fd.Borrow(), &request);
495 if (!status) {
496 if (status.error() == ESHUTDOWN) {
497 BuildCloseMessage(channel_id, message);
498 return {};
499 } else {
500 CloseChannel(channel_id);
501 return status;
502 }
503 }
504
505 MessageInfo info;
506 info.pid = request.cred.pid;
507 info.tid = -1;
508 info.cid = channel_id;
509 info.mid = request.is_impulse ? Message::IMPULSE_MESSAGE_ID
510 : GetNextAvailableMessageId();
511 info.euid = request.cred.uid;
512 info.egid = request.cred.gid;
513 info.op = request.op;
514 info.flags = 0;
515 info.service = service_;
516 info.channel = GetChannelState(channel_id);
517 info.send_len = request.send_len;
518 info.recv_len = request.max_recv_len;
519 info.fd_count = request.file_descriptors.size();
520 static_assert(sizeof(info.impulse) == request.impulse_payload.size(),
521 "Impulse payload sizes must be the same in RequestHeader and "
522 "MessageInfo");
523 memcpy(info.impulse, request.impulse_payload.data(),
524 request.impulse_payload.size());
525 *message = Message{info};
526 auto* state = static_cast<MessageState*>(message->GetState());
527 state->request = std::move(request);
528 if (request.send_len > 0 && !request.is_impulse) {
529 state->request_data.resize(request.send_len);
530 status = ReceiveData(channel_fd, state->request_data.data(),
531 state->request_data.size());
532 }
533
534 if (status && request.is_impulse)
535 status = ReenableEpollEvent(channel_fd);
536
537 if (!status) {
538 if (status.error() == ESHUTDOWN) {
539 BuildCloseMessage(channel_id, message);
540 return {};
541 } else {
542 CloseChannel(channel_id);
543 return status;
544 }
545 }
546
547 return status;
548 }
549
BuildCloseMessage(int32_t channel_id,Message * message)550 void Endpoint::BuildCloseMessage(int32_t channel_id, Message* message) {
551 ALOGD_IF(TRACE, "Endpoint::BuildCloseMessage: channel_id=%d", channel_id);
552 MessageInfo info;
553 info.pid = -1;
554 info.tid = -1;
555 info.cid = channel_id;
556 info.mid = GetNextAvailableMessageId();
557 info.euid = -1;
558 info.egid = -1;
559 info.op = opcodes::CHANNEL_CLOSE;
560 info.flags = 0;
561 info.service = service_;
562 info.channel = GetChannelState(channel_id);
563 info.send_len = 0;
564 info.recv_len = 0;
565 info.fd_count = 0;
566 *message = Message{info};
567 }
568
MessageReceive(Message * message)569 Status<void> Endpoint::MessageReceive(Message* message) {
570 // Receive at most one event from the epoll set. This should prevent multiple
571 // dispatch threads from attempting to handle messages on the same socket at
572 // the same time.
573 epoll_event event;
574 int count = RETRY_EINTR(
575 epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0));
576 if (count < 0) {
577 ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n",
578 strerror(errno));
579 return ErrorStatus{errno};
580 } else if (count == 0) {
581 return ErrorStatus{ETIMEDOUT};
582 }
583
584 if (event.data.fd == cancel_event_fd_.Get()) {
585 return ErrorStatus{ESHUTDOWN};
586 }
587
588 if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
589 auto status = AcceptConnection(message);
590 if (!status)
591 return status;
592 return ReenableEpollEvent(socket_fd_.Borrow());
593 }
594
595 BorrowedHandle channel_fd{event.data.fd};
596 if (event.events & (EPOLLRDHUP | EPOLLHUP)) {
597 BuildCloseMessage(GetChannelId(channel_fd), message);
598 return {};
599 }
600
601 return ReceiveMessageForChannel(channel_fd, message);
602 }
603
MessageReply(Message * message,int return_code)604 Status<void> Endpoint::MessageReply(Message* message, int return_code) {
605 const int32_t channel_id = message->GetChannelId();
606 auto channel_socket = GetChannelSocketFd(channel_id);
607 if (!channel_socket)
608 return ErrorStatus{EBADF};
609
610 auto* state = static_cast<MessageState*>(message->GetState());
611 switch (message->GetOp()) {
612 case opcodes::CHANNEL_CLOSE:
613 return CloseChannel(channel_id);
614
615 case opcodes::CHANNEL_OPEN:
616 if (return_code < 0) {
617 return CloseChannel(channel_id);
618 } else {
619 // Reply with the event fd.
620 auto push_status = state->PushFileHandle(GetChannelEventFd(channel_id));
621 state->response_data.clear(); // Just in case...
622 if (!push_status)
623 return push_status.error_status();
624 return_code = push_status.get();
625 }
626 break;
627 }
628
629 state->response.ret_code = return_code;
630 state->response.recv_len = state->response_data.size();
631 auto status = SendData(channel_socket, state->response);
632 if (status && !state->response_data.empty()) {
633 status = SendData(channel_socket, state->response_data.data(),
634 state->response_data.size());
635 }
636
637 if (status)
638 status = ReenableEpollEvent(channel_socket);
639
640 return status;
641 }
642
MessageReplyFd(Message * message,unsigned int push_fd)643 Status<void> Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) {
644 auto* state = static_cast<MessageState*>(message->GetState());
645 auto ref = state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)});
646 if (!ref)
647 return ref.error_status();
648 return MessageReply(message, ref.get());
649 }
650
MessageReplyChannelHandle(Message * message,const LocalChannelHandle & handle)651 Status<void> Endpoint::MessageReplyChannelHandle(
652 Message* message, const LocalChannelHandle& handle) {
653 auto* state = static_cast<MessageState*>(message->GetState());
654 auto ref = state->PushChannelHandle(handle.Borrow());
655 if (!ref)
656 return ref.error_status();
657 return MessageReply(message, ref.get());
658 }
659
MessageReplyChannelHandle(Message * message,const BorrowedChannelHandle & handle)660 Status<void> Endpoint::MessageReplyChannelHandle(
661 Message* message, const BorrowedChannelHandle& handle) {
662 auto* state = static_cast<MessageState*>(message->GetState());
663 auto ref = state->PushChannelHandle(handle.Duplicate());
664 if (!ref)
665 return ref.error_status();
666 return MessageReply(message, ref.get());
667 }
668
MessageReplyChannelHandle(Message * message,const RemoteChannelHandle & handle)669 Status<void> Endpoint::MessageReplyChannelHandle(
670 Message* message, const RemoteChannelHandle& handle) {
671 return MessageReply(message, handle.value());
672 }
673
ReadMessageData(Message * message,const iovec * vector,size_t vector_length)674 Status<size_t> Endpoint::ReadMessageData(Message* message, const iovec* vector,
675 size_t vector_length) {
676 auto* state = static_cast<MessageState*>(message->GetState());
677 return state->ReadData(vector, vector_length);
678 }
679
WriteMessageData(Message * message,const iovec * vector,size_t vector_length)680 Status<size_t> Endpoint::WriteMessageData(Message* message, const iovec* vector,
681 size_t vector_length) {
682 auto* state = static_cast<MessageState*>(message->GetState());
683 return state->WriteData(vector, vector_length);
684 }
685
PushFileHandle(Message * message,const LocalHandle & handle)686 Status<FileReference> Endpoint::PushFileHandle(Message* message,
687 const LocalHandle& handle) {
688 auto* state = static_cast<MessageState*>(message->GetState());
689 return state->PushFileHandle(handle.Borrow());
690 }
691
PushFileHandle(Message * message,const BorrowedHandle & handle)692 Status<FileReference> Endpoint::PushFileHandle(Message* message,
693 const BorrowedHandle& handle) {
694 auto* state = static_cast<MessageState*>(message->GetState());
695 return state->PushFileHandle(handle.Duplicate());
696 }
697
PushFileHandle(Message *,const RemoteHandle & handle)698 Status<FileReference> Endpoint::PushFileHandle(Message* /*message*/,
699 const RemoteHandle& handle) {
700 return handle.Get();
701 }
702
PushChannelHandle(Message * message,const LocalChannelHandle & handle)703 Status<ChannelReference> Endpoint::PushChannelHandle(
704 Message* message, const LocalChannelHandle& handle) {
705 auto* state = static_cast<MessageState*>(message->GetState());
706 return state->PushChannelHandle(handle.Borrow());
707 }
708
PushChannelHandle(Message * message,const BorrowedChannelHandle & handle)709 Status<ChannelReference> Endpoint::PushChannelHandle(
710 Message* message, const BorrowedChannelHandle& handle) {
711 auto* state = static_cast<MessageState*>(message->GetState());
712 return state->PushChannelHandle(handle.Duplicate());
713 }
714
PushChannelHandle(Message *,const RemoteChannelHandle & handle)715 Status<ChannelReference> Endpoint::PushChannelHandle(
716 Message* /*message*/, const RemoteChannelHandle& handle) {
717 return handle.value();
718 }
719
GetFileHandle(Message * message,FileReference ref) const720 LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const {
721 LocalHandle handle;
722 auto* state = static_cast<MessageState*>(message->GetState());
723 state->GetLocalFileHandle(ref, &handle);
724 return handle;
725 }
726
GetChannelHandle(Message * message,ChannelReference ref) const727 LocalChannelHandle Endpoint::GetChannelHandle(Message* message,
728 ChannelReference ref) const {
729 LocalChannelHandle handle;
730 auto* state = static_cast<MessageState*>(message->GetState());
731 state->GetLocalChannelHandle(ref, &handle);
732 return handle;
733 }
734
Cancel()735 Status<void> Endpoint::Cancel() {
736 if (eventfd_write(cancel_event_fd_.Get(), 1) < 0)
737 return ErrorStatus{errno};
738 return {};
739 }
740
Create(const std::string & endpoint_path,mode_t,bool blocking)741 std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path,
742 mode_t /*unused_mode*/,
743 bool blocking) {
744 return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking));
745 }
746
CreateAndBindSocket(const std::string & endpoint_path,bool blocking)747 std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
748 const std::string& endpoint_path, bool blocking) {
749 return std::unique_ptr<Endpoint>(
750 new Endpoint(endpoint_path, blocking, false));
751 }
752
CreateFromSocketFd(LocalHandle socket_fd)753 std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
754 return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
755 }
756
RegisterNewChannelForTests(LocalHandle channel_fd)757 Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
758 int optval = 1;
759 if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
760 sizeof(optval)) == -1) {
761 ALOGE(
762 "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
763 "of the credentials for channel %d: %s",
764 channel_fd.Get(), strerror(errno));
765 return ErrorStatus(errno);
766 }
767 return OnNewChannel(std::move(channel_fd));
768 }
769
770 } // namespace uds
771 } // namespace pdx
772 } // namespace android
773