1 #include "uds/service_endpoint.h"
2
3 #include <poll.h>
4 #include <sys/epoll.h>
5 #include <sys/eventfd.h>
6 #include <sys/socket.h>
7 #include <sys/un.h>
8 #include <algorithm> // std::min
9
10 #include <android-base/logging.h>
11 #include <android-base/strings.h>
12 #include <cutils/sockets.h>
13 #include <pdx/service.h>
14 #include <selinux/selinux.h>
15 #include <uds/channel_manager.h>
16 #include <uds/client_channel_factory.h>
17 #include <uds/ipc_helper.h>
18
19 namespace {
20
21 constexpr int kMaxBackLogForSocketListen = 1;
22
23 using android::pdx::BorrowedChannelHandle;
24 using android::pdx::BorrowedHandle;
25 using android::pdx::ChannelReference;
26 using android::pdx::ErrorStatus;
27 using android::pdx::FileReference;
28 using android::pdx::LocalChannelHandle;
29 using android::pdx::LocalHandle;
30 using android::pdx::Status;
31 using android::pdx::uds::ChannelInfo;
32 using android::pdx::uds::ChannelManager;
33
34 struct MessageState {
GetLocalFileHandle__anon9a8a71d40111::MessageState35 bool GetLocalFileHandle(int index, LocalHandle* handle) {
36 if (index < 0) {
37 handle->Reset(index);
38 } else if (static_cast<size_t>(index) < request.file_descriptors.size()) {
39 *handle = std::move(request.file_descriptors[index]);
40 } else {
41 return false;
42 }
43 return true;
44 }
45
GetLocalChannelHandle__anon9a8a71d40111::MessageState46 bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
47 if (index < 0) {
48 *handle = LocalChannelHandle{nullptr, index};
49 } else if (static_cast<size_t>(index) < request.channels.size()) {
50 auto& channel_info = request.channels[index];
51 *handle = ChannelManager::Get().CreateHandle(
52 std::move(channel_info.data_fd),
53 std::move(channel_info.pollin_event_fd),
54 std::move(channel_info.pollhup_event_fd));
55 } else {
56 return false;
57 }
58 return true;
59 }
60
PushFileHandle__anon9a8a71d40111::MessageState61 Status<FileReference> PushFileHandle(BorrowedHandle handle) {
62 if (!handle)
63 return handle.Get();
64 response.file_descriptors.push_back(std::move(handle));
65 return response.file_descriptors.size() - 1;
66 }
67
PushChannelHandle__anon9a8a71d40111::MessageState68 Status<ChannelReference> PushChannelHandle(BorrowedChannelHandle handle) {
69 if (!handle)
70 return handle.value();
71
72 if (auto* channel_data =
73 ChannelManager::Get().GetChannelData(handle.value())) {
74 ChannelInfo<BorrowedHandle> channel_info{
75 channel_data->data_fd(), channel_data->pollin_event_fd(),
76 channel_data->pollhup_event_fd()};
77 response.channels.push_back(std::move(channel_info));
78 return response.channels.size() - 1;
79 } else {
80 return ErrorStatus{EINVAL};
81 }
82 }
83
PushChannelHandle__anon9a8a71d40111::MessageState84 Status<ChannelReference> PushChannelHandle(BorrowedHandle data_fd,
85 BorrowedHandle pollin_event_fd,
86 BorrowedHandle pollhup_event_fd) {
87 if (!data_fd || !pollin_event_fd || !pollhup_event_fd)
88 return ErrorStatus{EINVAL};
89 ChannelInfo<BorrowedHandle> channel_info{std::move(data_fd),
90 std::move(pollin_event_fd),
91 std::move(pollhup_event_fd)};
92 response.channels.push_back(std::move(channel_info));
93 return response.channels.size() - 1;
94 }
95
WriteData__anon9a8a71d40111::MessageState96 Status<size_t> WriteData(const iovec* vector, size_t vector_length) {
97 size_t size = 0;
98 for (size_t i = 0; i < vector_length; i++) {
99 const auto* data = reinterpret_cast<const uint8_t*>(vector[i].iov_base);
100 response_data.insert(response_data.end(), data, data + vector[i].iov_len);
101 size += vector[i].iov_len;
102 }
103 return size;
104 }
105
ReadData__anon9a8a71d40111::MessageState106 Status<size_t> ReadData(const iovec* vector, size_t vector_length) {
107 size_t size_remaining = request_data.size() - request_data_read_pos;
108 size_t size = 0;
109 for (size_t i = 0; i < vector_length && size_remaining > 0; i++) {
110 size_t size_to_copy = std::min(size_remaining, vector[i].iov_len);
111 memcpy(vector[i].iov_base, request_data.data() + request_data_read_pos,
112 size_to_copy);
113 size += size_to_copy;
114 request_data_read_pos += size_to_copy;
115 size_remaining -= size_to_copy;
116 }
117 return size;
118 }
119
120 android::pdx::uds::RequestHeader<LocalHandle> request;
121 android::pdx::uds::ResponseHeader<BorrowedHandle> response;
122 std::vector<LocalHandle> sockets_to_close;
123 std::vector<uint8_t> request_data;
124 size_t request_data_read_pos{0};
125 std::vector<uint8_t> response_data;
126 };
127
128 } // anonymous namespace
129
130 namespace android {
131 namespace pdx {
132 namespace uds {
133
Endpoint(const std::string & endpoint_path,bool blocking,bool use_init_socket_fd)134 Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
135 bool use_init_socket_fd)
136 : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)},
137 is_blocking_{blocking} {
138 LocalHandle fd;
139 if (use_init_socket_fd) {
140 // Cut off the /dev/socket/ prefix from the full socket path and use the
141 // resulting "name" to retrieve the file descriptor for the socket created
142 // by the init process.
143 constexpr char prefix[] = "/dev/socket/";
144 CHECK(android::base::StartsWith(endpoint_path_, prefix))
145 << "Endpoint::Endpoint: Socket name '" << endpoint_path_
146 << "' must begin with '" << prefix << "'";
147 std::string socket_name = endpoint_path_.substr(sizeof(prefix) - 1);
148 fd.Reset(android_get_control_socket(socket_name.c_str()));
149 CHECK(fd.IsValid())
150 << "Endpoint::Endpoint: Unable to obtain the control socket fd for '"
151 << socket_name << "'";
152 fcntl(fd.Get(), F_SETFD, FD_CLOEXEC);
153 } else {
154 fd.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0));
155 CHECK(fd.IsValid()) << "Endpoint::Endpoint: Failed to create socket: "
156 << strerror(errno);
157
158 sockaddr_un local;
159 local.sun_family = AF_UNIX;
160 strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path));
161 local.sun_path[sizeof(local.sun_path) - 1] = '\0';
162
163 unlink(local.sun_path);
164 int ret =
165 bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
166 CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
167 }
168 Init(std::move(fd));
169 }
170
Endpoint(LocalHandle socket_fd)171 Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }
172
Init(LocalHandle socket_fd)173 void Endpoint::Init(LocalHandle socket_fd) {
174 if (socket_fd) {
175 CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
176 << "Endpoint::Endpoint: listen error: " << strerror(errno);
177 }
178 cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
179 CHECK(cancel_event_fd_.IsValid())
180 << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
181
182 epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC));
183 CHECK(epoll_fd_.IsValid())
184 << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);
185
186 if (socket_fd) {
187 epoll_event socket_event;
188 socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
189 socket_event.data.fd = socket_fd.Get();
190 int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
191 &socket_event);
192 CHECK_EQ(ret, 0)
193 << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
194 << strerror(errno);
195 }
196
197 epoll_event cancel_event;
198 cancel_event.events = EPOLLIN;
199 cancel_event.data.fd = cancel_event_fd_.Get();
200
201 int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
202 &cancel_event);
203 CHECK_EQ(ret, 0)
204 << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
205 << strerror(errno);
206 socket_fd_ = std::move(socket_fd);
207 }
208
AllocateMessageState()209 void* Endpoint::AllocateMessageState() { return new MessageState; }
210
FreeMessageState(void * state)211 void Endpoint::FreeMessageState(void* state) {
212 delete static_cast<MessageState*>(state);
213 }
214
AcceptConnection(Message * message)215 Status<void> Endpoint::AcceptConnection(Message* message) {
216 if (!socket_fd_)
217 return ErrorStatus(EBADF);
218
219 sockaddr_un remote;
220 socklen_t addrlen = sizeof(remote);
221 LocalHandle connection_fd{accept4(socket_fd_.Get(),
222 reinterpret_cast<sockaddr*>(&remote),
223 &addrlen, SOCK_CLOEXEC)};
224 if (!connection_fd) {
225 ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
226 strerror(errno));
227 return ErrorStatus(errno);
228 }
229
230 LocalHandle local_socket;
231 LocalHandle remote_socket;
232 auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
233 if (!status)
234 return status;
235
236 // Borrow the local channel handle before we move it into OnNewChannel().
237 BorrowedHandle channel_handle = local_socket.Borrow();
238 status = OnNewChannel(std::move(local_socket));
239 if (!status)
240 return status;
241
242 // Send the channel socket fd to the client.
243 ChannelConnectionInfo<LocalHandle> connection_info;
244 connection_info.channel_fd = std::move(remote_socket);
245 status = SendData(connection_fd.Borrow(), connection_info);
246
247 if (status) {
248 // Get the CHANNEL_OPEN message from client over the channel socket.
249 status = ReceiveMessageForChannel(channel_handle, message);
250 } else {
251 CloseChannel(GetChannelId(channel_handle));
252 }
253
254 // Don't need the connection socket anymore. Further communication should
255 // happen over the channel socket.
256 shutdown(connection_fd.Get(), SHUT_WR);
257 return status;
258 }
259
SetService(Service * service)260 Status<void> Endpoint::SetService(Service* service) {
261 service_ = service;
262 return {};
263 }
264
SetChannel(int channel_id,Channel * channel)265 Status<void> Endpoint::SetChannel(int channel_id, Channel* channel) {
266 std::lock_guard<std::mutex> autolock(channel_mutex_);
267 auto channel_data = channels_.find(channel_id);
268 if (channel_data == channels_.end())
269 return ErrorStatus{EINVAL};
270 channel_data->second.channel_state = channel;
271 return {};
272 }
273
OnNewChannel(LocalHandle channel_fd)274 Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) {
275 std::lock_guard<std::mutex> autolock(channel_mutex_);
276 Status<void> status;
277 status.PropagateError(OnNewChannelLocked(std::move(channel_fd), nullptr));
278 return status;
279 }
280
OnNewChannelLocked(LocalHandle channel_fd,Channel * channel_state)281 Status<std::pair<int32_t, Endpoint::ChannelData*>> Endpoint::OnNewChannelLocked(
282 LocalHandle channel_fd, Channel* channel_state) {
283 epoll_event event;
284 event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
285 event.data.fd = channel_fd.Get();
286 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) {
287 ALOGE(
288 "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n",
289 strerror(errno));
290 return ErrorStatus(errno);
291 }
292 ChannelData channel_data;
293 channel_data.data_fd = std::move(channel_fd);
294 channel_data.channel_state = channel_state;
295 for (;;) {
296 // Try new channel IDs until we find one which is not already in the map.
297 if (last_channel_id_++ == std::numeric_limits<int32_t>::max())
298 last_channel_id_ = 1;
299 auto iter = channels_.lower_bound(last_channel_id_);
300 if (iter == channels_.end() || iter->first != last_channel_id_) {
301 channel_fd_to_id_.emplace(channel_data.data_fd.Get(), last_channel_id_);
302 iter = channels_.emplace_hint(iter, last_channel_id_,
303 std::move(channel_data));
304 return std::make_pair(last_channel_id_, &iter->second);
305 }
306 }
307 }
308
ReenableEpollEvent(const BorrowedHandle & fd)309 Status<void> Endpoint::ReenableEpollEvent(const BorrowedHandle& fd) {
310 epoll_event event;
311 event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
312 event.data.fd = fd.Get();
313 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd.Get(), &event) < 0) {
314 ALOGE(
315 "Endpoint::ReenableEpollEvent: Failed to re-enable channel to "
316 "endpoint: %s\n",
317 strerror(errno));
318 return ErrorStatus(errno);
319 }
320 return {};
321 }
322
CloseChannel(int channel_id)323 Status<void> Endpoint::CloseChannel(int channel_id) {
324 std::lock_guard<std::mutex> autolock(channel_mutex_);
325 return CloseChannelLocked(channel_id);
326 }
327
CloseChannelLocked(int32_t channel_id)328 Status<void> Endpoint::CloseChannelLocked(int32_t channel_id) {
329 ALOGD_IF(TRACE, "Endpoint::CloseChannelLocked: channel_id=%d", channel_id);
330
331 auto iter = channels_.find(channel_id);
332 if (iter == channels_.end())
333 return ErrorStatus{EINVAL};
334
335 int channel_fd = iter->second.data_fd.Get();
336 Status<void> status;
337 epoll_event ee; // See BUGS in man 2 epoll_ctl.
338 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_fd, &ee) < 0) {
339 status.SetError(errno);
340 ALOGE(
341 "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: "
342 "%s\n",
343 strerror(errno));
344 } else {
345 status.SetValue();
346 }
347
348 channel_fd_to_id_.erase(channel_fd);
349 channels_.erase(iter);
350 return status;
351 }
352
ModifyChannelEvents(int channel_id,int clear_mask,int set_mask)353 Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask,
354 int set_mask) {
355 std::lock_guard<std::mutex> autolock(channel_mutex_);
356
357 auto search = channels_.find(channel_id);
358 if (search != channels_.end()) {
359 auto& channel_data = search->second;
360 channel_data.event_set.ModifyEvents(clear_mask, set_mask);
361 return {};
362 }
363
364 return ErrorStatus{EINVAL};
365 }
366
CreateChannelSocketPair(LocalHandle * local_socket,LocalHandle * remote_socket)367 Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket,
368 LocalHandle* remote_socket) {
369 Status<void> status;
370 char* endpoint_context = nullptr;
371 // Make sure the channel socket has the correct SELinux label applied.
372 // Here we get the label from the endpoint file descriptor, which should be
373 // something like "u:object_r:pdx_service_endpoint_socket:s0" and replace
374 // "endpoint" with "channel" to produce the channel label such as this:
375 // "u:object_r:pdx_service_channel_socket:s0".
376 if (fgetfilecon_raw(socket_fd_.Get(), &endpoint_context) > 0) {
377 std::string channel_context = endpoint_context;
378 freecon(endpoint_context);
379 const std::string suffix = "_endpoint_socket";
380 auto pos = channel_context.find(suffix);
381 if (pos != std::string::npos) {
382 channel_context.replace(pos, suffix.size(), "_channel_socket");
383 } else {
384 ALOGW(
385 "Endpoint::CreateChannelSocketPair: Endpoint security context '%s' "
386 "does not contain expected substring '%s'",
387 channel_context.c_str(), suffix.c_str());
388 }
389 ALOGE_IF(setsockcreatecon_raw(channel_context.c_str()) == -1,
390 "Endpoint::CreateChannelSocketPair: Failed to set channel socket "
391 "security context: %s",
392 strerror(errno));
393 } else {
394 ALOGE(
395 "Endpoint::CreateChannelSocketPair: Failed to obtain the endpoint "
396 "socket's security context: %s",
397 strerror(errno));
398 }
399
400 int channel_pair[2] = {};
401 if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
402 ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s",
403 strerror(errno));
404 status.SetError(errno);
405 return status;
406 }
407
408 setsockcreatecon_raw(nullptr);
409
410 local_socket->Reset(channel_pair[0]);
411 remote_socket->Reset(channel_pair[1]);
412
413 int optval = 1;
414 if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval,
415 sizeof(optval)) == -1) {
416 ALOGE(
417 "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of "
418 "the credentials for channel %d: %s",
419 local_socket->Get(), strerror(errno));
420 status.SetError(errno);
421 }
422 return status;
423 }
424
PushChannel(Message * message,int,Channel * channel,int * channel_id)425 Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
426 int /*flags*/,
427 Channel* channel,
428 int* channel_id) {
429 LocalHandle local_socket;
430 LocalHandle remote_socket;
431 auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
432 if (!status)
433 return status.error_status();
434
435 std::lock_guard<std::mutex> autolock(channel_mutex_);
436 auto channel_data_status =
437 OnNewChannelLocked(std::move(local_socket), channel);
438 if (!channel_data_status)
439 return channel_data_status.error_status();
440
441 ChannelData* channel_data;
442 std::tie(*channel_id, channel_data) = channel_data_status.take();
443
444 // Flags are ignored for now.
445 // TODO(xiaohuit): Implement those.
446
447 auto* state = static_cast<MessageState*>(message->GetState());
448 Status<ChannelReference> ref = state->PushChannelHandle(
449 remote_socket.Borrow(), channel_data->event_set.pollin_event_fd(),
450 channel_data->event_set.pollhup_event_fd());
451 if (!ref)
452 return ref.error_status();
453 state->sockets_to_close.push_back(std::move(remote_socket));
454 return RemoteChannelHandle{ref.get()};
455 }
456
CheckChannel(const Message *,ChannelReference,Channel **)457 Status<int> Endpoint::CheckChannel(const Message* /*message*/,
458 ChannelReference /*ref*/,
459 Channel** /*channel*/) {
460 // TODO(xiaohuit): Implement this.
461 return ErrorStatus(EFAULT);
462 }
463
GetChannelState(int32_t channel_id)464 Channel* Endpoint::GetChannelState(int32_t channel_id) {
465 std::lock_guard<std::mutex> autolock(channel_mutex_);
466 auto channel_data = channels_.find(channel_id);
467 return (channel_data != channels_.end()) ? channel_data->second.channel_state
468 : nullptr;
469 }
470
GetChannelSocketFd(int32_t channel_id)471 BorrowedHandle Endpoint::GetChannelSocketFd(int32_t channel_id) {
472 std::lock_guard<std::mutex> autolock(channel_mutex_);
473 BorrowedHandle handle;
474 auto channel_data = channels_.find(channel_id);
475 if (channel_data != channels_.end())
476 handle = channel_data->second.data_fd.Borrow();
477 return handle;
478 }
479
GetChannelEventFd(int32_t channel_id)480 Status<std::pair<BorrowedHandle, BorrowedHandle>> Endpoint::GetChannelEventFd(
481 int32_t channel_id) {
482 std::lock_guard<std::mutex> autolock(channel_mutex_);
483 auto channel_data = channels_.find(channel_id);
484 if (channel_data != channels_.end()) {
485 return {{channel_data->second.event_set.pollin_event_fd(),
486 channel_data->second.event_set.pollhup_event_fd()}};
487 }
488 return ErrorStatus(ENOENT);
489 }
490
GetChannelId(const BorrowedHandle & channel_fd)491 int32_t Endpoint::GetChannelId(const BorrowedHandle& channel_fd) {
492 std::lock_guard<std::mutex> autolock(channel_mutex_);
493 auto iter = channel_fd_to_id_.find(channel_fd.Get());
494 return (iter != channel_fd_to_id_.end()) ? iter->second : -1;
495 }
496
ReceiveMessageForChannel(const BorrowedHandle & channel_fd,Message * message)497 Status<void> Endpoint::ReceiveMessageForChannel(
498 const BorrowedHandle& channel_fd, Message* message) {
499 RequestHeader<LocalHandle> request;
500 int32_t channel_id = GetChannelId(channel_fd);
501 auto status = ReceiveData(channel_fd.Borrow(), &request);
502 if (!status) {
503 if (status.error() == ESHUTDOWN) {
504 BuildCloseMessage(channel_id, message);
505 return {};
506 } else {
507 CloseChannel(channel_id);
508 return status;
509 }
510 }
511
512 MessageInfo info;
513 info.pid = request.cred.pid;
514 info.tid = -1;
515 info.cid = channel_id;
516 info.mid = request.is_impulse ? Message::IMPULSE_MESSAGE_ID
517 : GetNextAvailableMessageId();
518 info.euid = request.cred.uid;
519 info.egid = request.cred.gid;
520 info.op = request.op;
521 info.flags = 0;
522 info.service = service_;
523 info.channel = GetChannelState(channel_id);
524 if (info.channel != nullptr) {
525 info.channel->SetActiveProcessId(request.cred.pid);
526 }
527 info.send_len = request.send_len;
528 info.recv_len = request.max_recv_len;
529 info.fd_count = request.file_descriptors.size();
530 static_assert(sizeof(info.impulse) == request.impulse_payload.size(),
531 "Impulse payload sizes must be the same in RequestHeader and "
532 "MessageInfo");
533 memcpy(info.impulse, request.impulse_payload.data(),
534 request.impulse_payload.size());
535 *message = Message{info};
536 auto* state = static_cast<MessageState*>(message->GetState());
537 state->request = std::move(request);
538 if (state->request.send_len > 0 && !state->request.is_impulse) {
539 state->request_data.resize(state->request.send_len);
540 status = ReceiveData(channel_fd, state->request_data.data(),
541 state->request_data.size());
542 }
543
544 if (status && state->request.is_impulse)
545 status = ReenableEpollEvent(channel_fd);
546
547 if (!status) {
548 if (status.error() == ESHUTDOWN) {
549 BuildCloseMessage(channel_id, message);
550 return {};
551 } else {
552 CloseChannel(channel_id);
553 return status;
554 }
555 }
556
557 return status;
558 }
559
BuildCloseMessage(int32_t channel_id,Message * message)560 void Endpoint::BuildCloseMessage(int32_t channel_id, Message* message) {
561 ALOGD_IF(TRACE, "Endpoint::BuildCloseMessage: channel_id=%d", channel_id);
562 MessageInfo info;
563 info.pid = -1;
564 info.tid = -1;
565 info.cid = channel_id;
566 info.mid = GetNextAvailableMessageId();
567 info.euid = -1;
568 info.egid = -1;
569 info.op = opcodes::CHANNEL_CLOSE;
570 info.flags = 0;
571 info.service = service_;
572 info.channel = GetChannelState(channel_id);
573 info.send_len = 0;
574 info.recv_len = 0;
575 info.fd_count = 0;
576 *message = Message{info};
577 }
578
MessageReceive(Message * message)579 Status<void> Endpoint::MessageReceive(Message* message) {
580 // Receive at most one event from the epoll set. This should prevent multiple
581 // dispatch threads from attempting to handle messages on the same socket at
582 // the same time.
583 epoll_event event;
584 int count = RETRY_EINTR(
585 epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0));
586 if (count < 0) {
587 ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n",
588 strerror(errno));
589 return ErrorStatus{errno};
590 } else if (count == 0) {
591 return ErrorStatus{ETIMEDOUT};
592 }
593
594 if (event.data.fd == cancel_event_fd_.Get()) {
595 return ErrorStatus{ESHUTDOWN};
596 }
597
598 if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
599 auto status = AcceptConnection(message);
600 auto reenable_status = ReenableEpollEvent(socket_fd_.Borrow());
601 if (!reenable_status)
602 return reenable_status;
603 return status;
604 }
605
606 BorrowedHandle channel_fd{event.data.fd};
607 return ReceiveMessageForChannel(channel_fd, message);
608 }
609
MessageReply(Message * message,int return_code)610 Status<void> Endpoint::MessageReply(Message* message, int return_code) {
611 const int32_t channel_id = message->GetChannelId();
612 auto channel_socket = GetChannelSocketFd(channel_id);
613 if (!channel_socket)
614 return ErrorStatus{EBADF};
615
616 auto* state = static_cast<MessageState*>(message->GetState());
617 switch (message->GetOp()) {
618 case opcodes::CHANNEL_CLOSE:
619 return CloseChannel(channel_id);
620
621 case opcodes::CHANNEL_OPEN:
622 if (return_code < 0) {
623 return CloseChannel(channel_id);
624 } else {
625 // Open messages do not have a payload and may not transfer any channels
626 // or file descriptors on behalf of the service.
627 state->response_data.clear();
628 state->response.file_descriptors.clear();
629 state->response.channels.clear();
630
631 // Return the channel event-related fds in a single ChannelInfo entry
632 // with an empty data_fd member.
633 auto status = GetChannelEventFd(channel_id);
634 if (!status)
635 return status.error_status();
636
637 auto handles = status.take();
638 state->response.channels.push_back({BorrowedHandle(),
639 std::move(handles.first),
640 std::move(handles.second)});
641 return_code = 0;
642 }
643 break;
644 }
645
646 state->response.ret_code = return_code;
647 state->response.recv_len = state->response_data.size();
648 auto status = SendData(channel_socket, state->response);
649 if (status && !state->response_data.empty()) {
650 status = SendData(channel_socket, state->response_data.data(),
651 state->response_data.size());
652 }
653
654 if (status)
655 status = ReenableEpollEvent(channel_socket);
656
657 return status;
658 }
659
MessageReplyFd(Message * message,unsigned int push_fd)660 Status<void> Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) {
661 auto* state = static_cast<MessageState*>(message->GetState());
662 auto ref = state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)});
663 if (!ref)
664 return ref.error_status();
665 return MessageReply(message, ref.get());
666 }
667
MessageReplyChannelHandle(Message * message,const LocalChannelHandle & handle)668 Status<void> Endpoint::MessageReplyChannelHandle(
669 Message* message, const LocalChannelHandle& handle) {
670 auto* state = static_cast<MessageState*>(message->GetState());
671 auto ref = state->PushChannelHandle(handle.Borrow());
672 if (!ref)
673 return ref.error_status();
674 return MessageReply(message, ref.get());
675 }
676
MessageReplyChannelHandle(Message * message,const BorrowedChannelHandle & handle)677 Status<void> Endpoint::MessageReplyChannelHandle(
678 Message* message, const BorrowedChannelHandle& handle) {
679 auto* state = static_cast<MessageState*>(message->GetState());
680 auto ref = state->PushChannelHandle(handle.Duplicate());
681 if (!ref)
682 return ref.error_status();
683 return MessageReply(message, ref.get());
684 }
685
MessageReplyChannelHandle(Message * message,const RemoteChannelHandle & handle)686 Status<void> Endpoint::MessageReplyChannelHandle(
687 Message* message, const RemoteChannelHandle& handle) {
688 return MessageReply(message, handle.value());
689 }
690
ReadMessageData(Message * message,const iovec * vector,size_t vector_length)691 Status<size_t> Endpoint::ReadMessageData(Message* message, const iovec* vector,
692 size_t vector_length) {
693 auto* state = static_cast<MessageState*>(message->GetState());
694 return state->ReadData(vector, vector_length);
695 }
696
WriteMessageData(Message * message,const iovec * vector,size_t vector_length)697 Status<size_t> Endpoint::WriteMessageData(Message* message, const iovec* vector,
698 size_t vector_length) {
699 auto* state = static_cast<MessageState*>(message->GetState());
700 return state->WriteData(vector, vector_length);
701 }
702
PushFileHandle(Message * message,const LocalHandle & handle)703 Status<FileReference> Endpoint::PushFileHandle(Message* message,
704 const LocalHandle& handle) {
705 auto* state = static_cast<MessageState*>(message->GetState());
706 return state->PushFileHandle(handle.Borrow());
707 }
708
PushFileHandle(Message * message,const BorrowedHandle & handle)709 Status<FileReference> Endpoint::PushFileHandle(Message* message,
710 const BorrowedHandle& handle) {
711 auto* state = static_cast<MessageState*>(message->GetState());
712 return state->PushFileHandle(handle.Duplicate());
713 }
714
PushFileHandle(Message *,const RemoteHandle & handle)715 Status<FileReference> Endpoint::PushFileHandle(Message* /*message*/,
716 const RemoteHandle& handle) {
717 return handle.Get();
718 }
719
PushChannelHandle(Message * message,const LocalChannelHandle & handle)720 Status<ChannelReference> Endpoint::PushChannelHandle(
721 Message* message, const LocalChannelHandle& handle) {
722 auto* state = static_cast<MessageState*>(message->GetState());
723 return state->PushChannelHandle(handle.Borrow());
724 }
725
PushChannelHandle(Message * message,const BorrowedChannelHandle & handle)726 Status<ChannelReference> Endpoint::PushChannelHandle(
727 Message* message, const BorrowedChannelHandle& handle) {
728 auto* state = static_cast<MessageState*>(message->GetState());
729 return state->PushChannelHandle(handle.Duplicate());
730 }
731
PushChannelHandle(Message *,const RemoteChannelHandle & handle)732 Status<ChannelReference> Endpoint::PushChannelHandle(
733 Message* /*message*/, const RemoteChannelHandle& handle) {
734 return handle.value();
735 }
736
GetFileHandle(Message * message,FileReference ref) const737 LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const {
738 LocalHandle handle;
739 auto* state = static_cast<MessageState*>(message->GetState());
740 state->GetLocalFileHandle(ref, &handle);
741 return handle;
742 }
743
GetChannelHandle(Message * message,ChannelReference ref) const744 LocalChannelHandle Endpoint::GetChannelHandle(Message* message,
745 ChannelReference ref) const {
746 LocalChannelHandle handle;
747 auto* state = static_cast<MessageState*>(message->GetState());
748 state->GetLocalChannelHandle(ref, &handle);
749 return handle;
750 }
751
Cancel()752 Status<void> Endpoint::Cancel() {
753 if (eventfd_write(cancel_event_fd_.Get(), 1) < 0)
754 return ErrorStatus{errno};
755 return {};
756 }
757
Create(const std::string & endpoint_path,mode_t,bool blocking)758 std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path,
759 mode_t /*unused_mode*/,
760 bool blocking) {
761 return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking));
762 }
763
CreateAndBindSocket(const std::string & endpoint_path,bool blocking)764 std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
765 const std::string& endpoint_path, bool blocking) {
766 return std::unique_ptr<Endpoint>(
767 new Endpoint(endpoint_path, blocking, false));
768 }
769
CreateFromSocketFd(LocalHandle socket_fd)770 std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
771 return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
772 }
773
RegisterNewChannelForTests(LocalHandle channel_fd)774 Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
775 int optval = 1;
776 if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
777 sizeof(optval)) == -1) {
778 ALOGE(
779 "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
780 "of the credentials for channel %d: %s",
781 channel_fd.Get(), strerror(errno));
782 return ErrorStatus(errno);
783 }
784 return OnNewChannel(std::move(channel_fd));
785 }
786
787 } // namespace uds
788 } // namespace pdx
789 } // namespace android
790