1 #include "uds/service_endpoint.h"
2
#include <poll.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <algorithm> // std::min
#include <limits>    // std::numeric_limits
9
10 #include <android-base/logging.h>
11 #include <android-base/strings.h>
12 #include <cutils/sockets.h>
13 #include <pdx/service.h>
14 #include <selinux/selinux.h>
15 #include <uds/channel_manager.h>
16 #include <uds/client_channel_factory.h>
17 #include <uds/ipc_helper.h>
18
19 namespace {
20
// Backlog argument passed to listen() on the endpoint's listening socket.
constexpr int kMaxBackLogForSocketListen{1};
22
23 using android::pdx::BorrowedChannelHandle;
24 using android::pdx::BorrowedHandle;
25 using android::pdx::ChannelReference;
26 using android::pdx::ErrorStatus;
27 using android::pdx::FileReference;
28 using android::pdx::LocalChannelHandle;
29 using android::pdx::LocalHandle;
30 using android::pdx::Status;
31 using android::pdx::uds::ChannelInfo;
32 using android::pdx::uds::ChannelManager;
33
34 struct MessageState {
GetLocalFileHandle__anon9a8a71d40111::MessageState35 bool GetLocalFileHandle(int index, LocalHandle* handle) {
36 if (index < 0) {
37 handle->Reset(index);
38 } else if (static_cast<size_t>(index) < request.file_descriptors.size()) {
39 *handle = std::move(request.file_descriptors[index]);
40 } else {
41 return false;
42 }
43 return true;
44 }
45
GetLocalChannelHandle__anon9a8a71d40111::MessageState46 bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
47 if (index < 0) {
48 *handle = LocalChannelHandle{nullptr, index};
49 } else if (static_cast<size_t>(index) < request.channels.size()) {
50 auto& channel_info = request.channels[index];
51 *handle = ChannelManager::Get().CreateHandle(
52 std::move(channel_info.data_fd),
53 std::move(channel_info.pollin_event_fd),
54 std::move(channel_info.pollhup_event_fd));
55 } else {
56 return false;
57 }
58 return true;
59 }
60
PushFileHandle__anon9a8a71d40111::MessageState61 Status<FileReference> PushFileHandle(BorrowedHandle handle) {
62 if (!handle)
63 return handle.Get();
64 response.file_descriptors.push_back(std::move(handle));
65 return response.file_descriptors.size() - 1;
66 }
67
PushChannelHandle__anon9a8a71d40111::MessageState68 Status<ChannelReference> PushChannelHandle(BorrowedChannelHandle handle) {
69 if (!handle)
70 return handle.value();
71
72 if (auto* channel_data =
73 ChannelManager::Get().GetChannelData(handle.value())) {
74 ChannelInfo<BorrowedHandle> channel_info{
75 channel_data->data_fd(), channel_data->pollin_event_fd(),
76 channel_data->pollhup_event_fd()};
77 response.channels.push_back(std::move(channel_info));
78 return response.channels.size() - 1;
79 } else {
80 return ErrorStatus{EINVAL};
81 }
82 }
83
PushChannelHandle__anon9a8a71d40111::MessageState84 Status<ChannelReference> PushChannelHandle(BorrowedHandle data_fd,
85 BorrowedHandle pollin_event_fd,
86 BorrowedHandle pollhup_event_fd) {
87 if (!data_fd || !pollin_event_fd || !pollhup_event_fd)
88 return ErrorStatus{EINVAL};
89 ChannelInfo<BorrowedHandle> channel_info{std::move(data_fd),
90 std::move(pollin_event_fd),
91 std::move(pollhup_event_fd)};
92 response.channels.push_back(std::move(channel_info));
93 return response.channels.size() - 1;
94 }
95
WriteData__anon9a8a71d40111::MessageState96 Status<size_t> WriteData(const iovec* vector, size_t vector_length) {
97 size_t size = 0;
98 for (size_t i = 0; i < vector_length; i++) {
99 const auto* data = reinterpret_cast<const uint8_t*>(vector[i].iov_base);
100 response_data.insert(response_data.end(), data, data + vector[i].iov_len);
101 size += vector[i].iov_len;
102 }
103 return size;
104 }
105
ReadData__anon9a8a71d40111::MessageState106 Status<size_t> ReadData(const iovec* vector, size_t vector_length) {
107 size_t size_remaining = request_data.size() - request_data_read_pos;
108 size_t size = 0;
109 for (size_t i = 0; i < vector_length && size_remaining > 0; i++) {
110 size_t size_to_copy = std::min(size_remaining, vector[i].iov_len);
111 memcpy(vector[i].iov_base, request_data.data() + request_data_read_pos,
112 size_to_copy);
113 size += size_to_copy;
114 request_data_read_pos += size_to_copy;
115 size_remaining -= size_to_copy;
116 }
117 return size;
118 }
119
120 android::pdx::uds::RequestHeader<LocalHandle> request;
121 android::pdx::uds::ResponseHeader<BorrowedHandle> response;
122 std::vector<LocalHandle> sockets_to_close;
123 std::vector<uint8_t> request_data;
124 size_t request_data_read_pos{0};
125 std::vector<uint8_t> response_data;
126 };
127
128 } // anonymous namespace
129
130 namespace android {
131 namespace pdx {
132 namespace uds {
133
Endpoint(const std::string & endpoint_path,bool blocking,bool use_init_socket_fd)134 Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
135 bool use_init_socket_fd)
136 : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)},
137 is_blocking_{blocking} {
138 LocalHandle fd;
139 if (use_init_socket_fd) {
140 // Cut off the /dev/socket/ prefix from the full socket path and use the
141 // resulting "name" to retrieve the file descriptor for the socket created
142 // by the init process.
143 constexpr char prefix[] = "/dev/socket/";
144 CHECK(android::base::StartsWith(endpoint_path_, prefix))
145 << "Endpoint::Endpoint: Socket name '" << endpoint_path_
146 << "' must begin with '" << prefix << "'";
147 std::string socket_name = endpoint_path_.substr(sizeof(prefix) - 1);
148 fd.Reset(android_get_control_socket(socket_name.c_str()));
149 CHECK(fd.IsValid())
150 << "Endpoint::Endpoint: Unable to obtain the control socket fd for '"
151 << socket_name << "'";
152 fcntl(fd.Get(), F_SETFD, FD_CLOEXEC);
153 } else {
154 fd.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0));
155 CHECK(fd.IsValid()) << "Endpoint::Endpoint: Failed to create socket: "
156 << strerror(errno);
157
158 sockaddr_un local;
159 local.sun_family = AF_UNIX;
160 strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path));
161 local.sun_path[sizeof(local.sun_path) - 1] = '\0';
162
163 unlink(local.sun_path);
164 int ret =
165 bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
166 CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
167 }
168 Init(std::move(fd));
169 }
170
Endpoint(LocalHandle socket_fd)171 Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }
172
Init(LocalHandle socket_fd)173 void Endpoint::Init(LocalHandle socket_fd) {
174 if (socket_fd) {
175 CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
176 << "Endpoint::Endpoint: listen error: " << strerror(errno);
177 }
178 cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
179 CHECK(cancel_event_fd_.IsValid())
180 << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
181
182 epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC));
183 CHECK(epoll_fd_.IsValid())
184 << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);
185
186 if (socket_fd) {
187 epoll_event socket_event;
188 socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
189 socket_event.data.fd = socket_fd.Get();
190 int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
191 &socket_event);
192 CHECK_EQ(ret, 0)
193 << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
194 << strerror(errno);
195 }
196
197 epoll_event cancel_event;
198 cancel_event.events = EPOLLIN;
199 cancel_event.data.fd = cancel_event_fd_.Get();
200
201 int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
202 &cancel_event);
203 CHECK_EQ(ret, 0)
204 << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
205 << strerror(errno);
206 socket_fd_ = std::move(socket_fd);
207 }
208
AllocateMessageState()209 void* Endpoint::AllocateMessageState() { return new MessageState; }
210
FreeMessageState(void * state)211 void Endpoint::FreeMessageState(void* state) {
212 delete static_cast<MessageState*>(state);
213 }
214
AcceptConnection(Message * message)215 Status<void> Endpoint::AcceptConnection(Message* message) {
216 if (!socket_fd_)
217 return ErrorStatus(EBADF);
218
219 sockaddr_un remote;
220 socklen_t addrlen = sizeof(remote);
221 LocalHandle connection_fd{accept4(socket_fd_.Get(),
222 reinterpret_cast<sockaddr*>(&remote),
223 &addrlen, SOCK_CLOEXEC)};
224 if (!connection_fd) {
225 ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
226 strerror(errno));
227 return ErrorStatus(errno);
228 }
229
230 LocalHandle local_socket;
231 LocalHandle remote_socket;
232 auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
233 if (!status)
234 return status;
235
236 // Borrow the local channel handle before we move it into OnNewChannel().
237 BorrowedHandle channel_handle = local_socket.Borrow();
238 status = OnNewChannel(std::move(local_socket));
239 if (!status)
240 return status;
241
242 // Send the channel socket fd to the client.
243 ChannelConnectionInfo<LocalHandle> connection_info;
244 connection_info.channel_fd = std::move(remote_socket);
245 status = SendData(connection_fd.Borrow(), connection_info);
246
247 if (status) {
248 // Get the CHANNEL_OPEN message from client over the channel socket.
249 status = ReceiveMessageForChannel(channel_handle, message);
250 } else {
251 CloseChannel(GetChannelId(channel_handle));
252 }
253
254 // Don't need the connection socket anymore. Further communication should
255 // happen over the channel socket.
256 shutdown(connection_fd.Get(), SHUT_WR);
257 return status;
258 }
259
SetService(Service * service)260 Status<void> Endpoint::SetService(Service* service) {
261 service_ = service;
262 return {};
263 }
264
SetChannel(int channel_id,Channel * channel)265 Status<void> Endpoint::SetChannel(int channel_id, Channel* channel) {
266 std::lock_guard<std::mutex> autolock(channel_mutex_);
267 auto channel_data = channels_.find(channel_id);
268 if (channel_data == channels_.end())
269 return ErrorStatus{EINVAL};
270 channel_data->second.channel_state = channel;
271 return {};
272 }
273
OnNewChannel(LocalHandle channel_fd)274 Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) {
275 std::lock_guard<std::mutex> autolock(channel_mutex_);
276 Status<void> status;
277 status.PropagateError(OnNewChannelLocked(std::move(channel_fd), nullptr));
278 return status;
279 }
280
OnNewChannelLocked(LocalHandle channel_fd,Channel * channel_state)281 Status<std::pair<int32_t, Endpoint::ChannelData*>> Endpoint::OnNewChannelLocked(
282 LocalHandle channel_fd, Channel* channel_state) {
283 epoll_event event;
284 event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
285 event.data.fd = channel_fd.Get();
286 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) {
287 ALOGE(
288 "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n",
289 strerror(errno));
290 return ErrorStatus(errno);
291 }
292 ChannelData channel_data;
293 channel_data.data_fd = std::move(channel_fd);
294 channel_data.channel_state = channel_state;
295 for (;;) {
296 // Try new channel IDs until we find one which is not already in the map.
297 if (last_channel_id_++ == std::numeric_limits<int32_t>::max())
298 last_channel_id_ = 1;
299 auto iter = channels_.lower_bound(last_channel_id_);
300 if (iter == channels_.end() || iter->first != last_channel_id_) {
301 channel_fd_to_id_.emplace(channel_data.data_fd.Get(), last_channel_id_);
302 iter = channels_.emplace_hint(iter, last_channel_id_,
303 std::move(channel_data));
304 return std::make_pair(last_channel_id_, &iter->second);
305 }
306 }
307 }
308
ReenableEpollEvent(const BorrowedHandle & fd)309 Status<void> Endpoint::ReenableEpollEvent(const BorrowedHandle& fd) {
310 epoll_event event;
311 event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
312 event.data.fd = fd.Get();
313 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd.Get(), &event) < 0) {
314 ALOGE(
315 "Endpoint::ReenableEpollEvent: Failed to re-enable channel to "
316 "endpoint: %s\n",
317 strerror(errno));
318 return ErrorStatus(errno);
319 }
320 return {};
321 }
322
CloseChannel(int channel_id)323 Status<void> Endpoint::CloseChannel(int channel_id) {
324 std::lock_guard<std::mutex> autolock(channel_mutex_);
325 return CloseChannelLocked(channel_id);
326 }
327
CloseChannelLocked(int32_t channel_id)328 Status<void> Endpoint::CloseChannelLocked(int32_t channel_id) {
329 ALOGD_IF(TRACE, "Endpoint::CloseChannelLocked: channel_id=%d", channel_id);
330
331 auto iter = channels_.find(channel_id);
332 if (iter == channels_.end())
333 return ErrorStatus{EINVAL};
334
335 int channel_fd = iter->second.data_fd.Get();
336 Status<void> status;
337 epoll_event dummy; // See BUGS in man 2 epoll_ctl.
338 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_fd, &dummy) < 0) {
339 status.SetError(errno);
340 ALOGE(
341 "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: "
342 "%s\n",
343 strerror(errno));
344 } else {
345 status.SetValue();
346 }
347
348 channel_fd_to_id_.erase(channel_fd);
349 channels_.erase(iter);
350 return status;
351 }
352
ModifyChannelEvents(int channel_id,int clear_mask,int set_mask)353 Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask,
354 int set_mask) {
355 std::lock_guard<std::mutex> autolock(channel_mutex_);
356
357 auto search = channels_.find(channel_id);
358 if (search != channels_.end()) {
359 auto& channel_data = search->second;
360 channel_data.event_set.ModifyEvents(clear_mask, set_mask);
361 return {};
362 }
363
364 return ErrorStatus{EINVAL};
365 }
366
CreateChannelSocketPair(LocalHandle * local_socket,LocalHandle * remote_socket)367 Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket,
368 LocalHandle* remote_socket) {
369 Status<void> status;
370 char* endpoint_context = nullptr;
371 // Make sure the channel socket has the correct SELinux label applied.
372 // Here we get the label from the endpoint file descriptor, which should be
373 // something like "u:object_r:pdx_service_endpoint_socket:s0" and replace
374 // "endpoint" with "channel" to produce the channel label such as this:
375 // "u:object_r:pdx_service_channel_socket:s0".
376 if (fgetfilecon_raw(socket_fd_.Get(), &endpoint_context) > 0) {
377 std::string channel_context = endpoint_context;
378 freecon(endpoint_context);
379 const std::string suffix = "_endpoint_socket";
380 auto pos = channel_context.find(suffix);
381 if (pos != std::string::npos) {
382 channel_context.replace(pos, suffix.size(), "_channel_socket");
383 } else {
384 ALOGW(
385 "Endpoint::CreateChannelSocketPair: Endpoint security context '%s' "
386 "does not contain expected substring '%s'",
387 channel_context.c_str(), suffix.c_str());
388 }
389 ALOGE_IF(setsockcreatecon_raw(channel_context.c_str()) == -1,
390 "Endpoint::CreateChannelSocketPair: Failed to set channel socket "
391 "security context: %s",
392 strerror(errno));
393 } else {
394 ALOGE(
395 "Endpoint::CreateChannelSocketPair: Failed to obtain the endpoint "
396 "socket's security context: %s",
397 strerror(errno));
398 }
399
400 int channel_pair[2] = {};
401 if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
402 ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s",
403 strerror(errno));
404 status.SetError(errno);
405 return status;
406 }
407
408 setsockcreatecon_raw(nullptr);
409
410 local_socket->Reset(channel_pair[0]);
411 remote_socket->Reset(channel_pair[1]);
412
413 int optval = 1;
414 if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval,
415 sizeof(optval)) == -1) {
416 ALOGE(
417 "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of "
418 "the credentials for channel %d: %s",
419 local_socket->Get(), strerror(errno));
420 status.SetError(errno);
421 }
422 return status;
423 }
424
PushChannel(Message * message,int,Channel * channel,int * channel_id)425 Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
426 int /*flags*/,
427 Channel* channel,
428 int* channel_id) {
429 LocalHandle local_socket;
430 LocalHandle remote_socket;
431 auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
432 if (!status)
433 return status.error_status();
434
435 std::lock_guard<std::mutex> autolock(channel_mutex_);
436 auto channel_data_status =
437 OnNewChannelLocked(std::move(local_socket), channel);
438 if (!channel_data_status)
439 return channel_data_status.error_status();
440
441 ChannelData* channel_data;
442 std::tie(*channel_id, channel_data) = channel_data_status.take();
443
444 // Flags are ignored for now.
445 // TODO(xiaohuit): Implement those.
446
447 auto* state = static_cast<MessageState*>(message->GetState());
448 Status<ChannelReference> ref = state->PushChannelHandle(
449 remote_socket.Borrow(), channel_data->event_set.pollin_event_fd(),
450 channel_data->event_set.pollhup_event_fd());
451 if (!ref)
452 return ref.error_status();
453 state->sockets_to_close.push_back(std::move(remote_socket));
454 return RemoteChannelHandle{ref.get()};
455 }
456
CheckChannel(const Message *,ChannelReference,Channel **)457 Status<int> Endpoint::CheckChannel(const Message* /*message*/,
458 ChannelReference /*ref*/,
459 Channel** /*channel*/) {
460 // TODO(xiaohuit): Implement this.
461 return ErrorStatus(EFAULT);
462 }
463
GetChannelState(int32_t channel_id)464 Channel* Endpoint::GetChannelState(int32_t channel_id) {
465 std::lock_guard<std::mutex> autolock(channel_mutex_);
466 auto channel_data = channels_.find(channel_id);
467 return (channel_data != channels_.end()) ? channel_data->second.channel_state
468 : nullptr;
469 }
470
GetChannelSocketFd(int32_t channel_id)471 BorrowedHandle Endpoint::GetChannelSocketFd(int32_t channel_id) {
472 std::lock_guard<std::mutex> autolock(channel_mutex_);
473 BorrowedHandle handle;
474 auto channel_data = channels_.find(channel_id);
475 if (channel_data != channels_.end())
476 handle = channel_data->second.data_fd.Borrow();
477 return handle;
478 }
479
GetChannelEventFd(int32_t channel_id)480 Status<std::pair<BorrowedHandle, BorrowedHandle>> Endpoint::GetChannelEventFd(
481 int32_t channel_id) {
482 std::lock_guard<std::mutex> autolock(channel_mutex_);
483 auto channel_data = channels_.find(channel_id);
484 if (channel_data != channels_.end()) {
485 return {{channel_data->second.event_set.pollin_event_fd(),
486 channel_data->second.event_set.pollhup_event_fd()}};
487 }
488 return ErrorStatus(ENOENT);
489 }
490
GetChannelId(const BorrowedHandle & channel_fd)491 int32_t Endpoint::GetChannelId(const BorrowedHandle& channel_fd) {
492 std::lock_guard<std::mutex> autolock(channel_mutex_);
493 auto iter = channel_fd_to_id_.find(channel_fd.Get());
494 return (iter != channel_fd_to_id_.end()) ? iter->second : -1;
495 }
496
ReceiveMessageForChannel(const BorrowedHandle & channel_fd,Message * message)497 Status<void> Endpoint::ReceiveMessageForChannel(
498 const BorrowedHandle& channel_fd, Message* message) {
499 RequestHeader<LocalHandle> request;
500 int32_t channel_id = GetChannelId(channel_fd);
501 auto status = ReceiveData(channel_fd.Borrow(), &request);
502 if (!status) {
503 if (status.error() == ESHUTDOWN) {
504 BuildCloseMessage(channel_id, message);
505 return {};
506 } else {
507 CloseChannel(channel_id);
508 return status;
509 }
510 }
511
512 MessageInfo info;
513 info.pid = request.cred.pid;
514 info.tid = -1;
515 info.cid = channel_id;
516 info.mid = request.is_impulse ? Message::IMPULSE_MESSAGE_ID
517 : GetNextAvailableMessageId();
518 info.euid = request.cred.uid;
519 info.egid = request.cred.gid;
520 info.op = request.op;
521 info.flags = 0;
522 info.service = service_;
523 info.channel = GetChannelState(channel_id);
524 info.send_len = request.send_len;
525 info.recv_len = request.max_recv_len;
526 info.fd_count = request.file_descriptors.size();
527 static_assert(sizeof(info.impulse) == request.impulse_payload.size(),
528 "Impulse payload sizes must be the same in RequestHeader and "
529 "MessageInfo");
530 memcpy(info.impulse, request.impulse_payload.data(),
531 request.impulse_payload.size());
532 *message = Message{info};
533 auto* state = static_cast<MessageState*>(message->GetState());
534 state->request = std::move(request);
535 if (request.send_len > 0 && !request.is_impulse) {
536 state->request_data.resize(request.send_len);
537 status = ReceiveData(channel_fd, state->request_data.data(),
538 state->request_data.size());
539 }
540
541 if (status && request.is_impulse)
542 status = ReenableEpollEvent(channel_fd);
543
544 if (!status) {
545 if (status.error() == ESHUTDOWN) {
546 BuildCloseMessage(channel_id, message);
547 return {};
548 } else {
549 CloseChannel(channel_id);
550 return status;
551 }
552 }
553
554 return status;
555 }
556
BuildCloseMessage(int32_t channel_id,Message * message)557 void Endpoint::BuildCloseMessage(int32_t channel_id, Message* message) {
558 ALOGD_IF(TRACE, "Endpoint::BuildCloseMessage: channel_id=%d", channel_id);
559 MessageInfo info;
560 info.pid = -1;
561 info.tid = -1;
562 info.cid = channel_id;
563 info.mid = GetNextAvailableMessageId();
564 info.euid = -1;
565 info.egid = -1;
566 info.op = opcodes::CHANNEL_CLOSE;
567 info.flags = 0;
568 info.service = service_;
569 info.channel = GetChannelState(channel_id);
570 info.send_len = 0;
571 info.recv_len = 0;
572 info.fd_count = 0;
573 *message = Message{info};
574 }
575
MessageReceive(Message * message)576 Status<void> Endpoint::MessageReceive(Message* message) {
577 // Receive at most one event from the epoll set. This should prevent multiple
578 // dispatch threads from attempting to handle messages on the same socket at
579 // the same time.
580 epoll_event event;
581 int count = RETRY_EINTR(
582 epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0));
583 if (count < 0) {
584 ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n",
585 strerror(errno));
586 return ErrorStatus{errno};
587 } else if (count == 0) {
588 return ErrorStatus{ETIMEDOUT};
589 }
590
591 if (event.data.fd == cancel_event_fd_.Get()) {
592 return ErrorStatus{ESHUTDOWN};
593 }
594
595 if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
596 auto status = AcceptConnection(message);
597 auto reenable_status = ReenableEpollEvent(socket_fd_.Borrow());
598 if (!reenable_status)
599 return reenable_status;
600 return status;
601 }
602
603 BorrowedHandle channel_fd{event.data.fd};
604 return ReceiveMessageForChannel(channel_fd, message);
605 }
606
MessageReply(Message * message,int return_code)607 Status<void> Endpoint::MessageReply(Message* message, int return_code) {
608 const int32_t channel_id = message->GetChannelId();
609 auto channel_socket = GetChannelSocketFd(channel_id);
610 if (!channel_socket)
611 return ErrorStatus{EBADF};
612
613 auto* state = static_cast<MessageState*>(message->GetState());
614 switch (message->GetOp()) {
615 case opcodes::CHANNEL_CLOSE:
616 return CloseChannel(channel_id);
617
618 case opcodes::CHANNEL_OPEN:
619 if (return_code < 0) {
620 return CloseChannel(channel_id);
621 } else {
622 // Open messages do not have a payload and may not transfer any channels
623 // or file descriptors on behalf of the service.
624 state->response_data.clear();
625 state->response.file_descriptors.clear();
626 state->response.channels.clear();
627
628 // Return the channel event-related fds in a single ChannelInfo entry
629 // with an empty data_fd member.
630 auto status = GetChannelEventFd(channel_id);
631 if (!status)
632 return status.error_status();
633
634 auto handles = status.take();
635 state->response.channels.push_back({BorrowedHandle(),
636 std::move(handles.first),
637 std::move(handles.second)});
638 return_code = 0;
639 }
640 break;
641 }
642
643 state->response.ret_code = return_code;
644 state->response.recv_len = state->response_data.size();
645 auto status = SendData(channel_socket, state->response);
646 if (status && !state->response_data.empty()) {
647 status = SendData(channel_socket, state->response_data.data(),
648 state->response_data.size());
649 }
650
651 if (status)
652 status = ReenableEpollEvent(channel_socket);
653
654 return status;
655 }
656
MessageReplyFd(Message * message,unsigned int push_fd)657 Status<void> Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) {
658 auto* state = static_cast<MessageState*>(message->GetState());
659 auto ref = state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)});
660 if (!ref)
661 return ref.error_status();
662 return MessageReply(message, ref.get());
663 }
664
MessageReplyChannelHandle(Message * message,const LocalChannelHandle & handle)665 Status<void> Endpoint::MessageReplyChannelHandle(
666 Message* message, const LocalChannelHandle& handle) {
667 auto* state = static_cast<MessageState*>(message->GetState());
668 auto ref = state->PushChannelHandle(handle.Borrow());
669 if (!ref)
670 return ref.error_status();
671 return MessageReply(message, ref.get());
672 }
673
MessageReplyChannelHandle(Message * message,const BorrowedChannelHandle & handle)674 Status<void> Endpoint::MessageReplyChannelHandle(
675 Message* message, const BorrowedChannelHandle& handle) {
676 auto* state = static_cast<MessageState*>(message->GetState());
677 auto ref = state->PushChannelHandle(handle.Duplicate());
678 if (!ref)
679 return ref.error_status();
680 return MessageReply(message, ref.get());
681 }
682
MessageReplyChannelHandle(Message * message,const RemoteChannelHandle & handle)683 Status<void> Endpoint::MessageReplyChannelHandle(
684 Message* message, const RemoteChannelHandle& handle) {
685 return MessageReply(message, handle.value());
686 }
687
ReadMessageData(Message * message,const iovec * vector,size_t vector_length)688 Status<size_t> Endpoint::ReadMessageData(Message* message, const iovec* vector,
689 size_t vector_length) {
690 auto* state = static_cast<MessageState*>(message->GetState());
691 return state->ReadData(vector, vector_length);
692 }
693
WriteMessageData(Message * message,const iovec * vector,size_t vector_length)694 Status<size_t> Endpoint::WriteMessageData(Message* message, const iovec* vector,
695 size_t vector_length) {
696 auto* state = static_cast<MessageState*>(message->GetState());
697 return state->WriteData(vector, vector_length);
698 }
699
PushFileHandle(Message * message,const LocalHandle & handle)700 Status<FileReference> Endpoint::PushFileHandle(Message* message,
701 const LocalHandle& handle) {
702 auto* state = static_cast<MessageState*>(message->GetState());
703 return state->PushFileHandle(handle.Borrow());
704 }
705
PushFileHandle(Message * message,const BorrowedHandle & handle)706 Status<FileReference> Endpoint::PushFileHandle(Message* message,
707 const BorrowedHandle& handle) {
708 auto* state = static_cast<MessageState*>(message->GetState());
709 return state->PushFileHandle(handle.Duplicate());
710 }
711
PushFileHandle(Message *,const RemoteHandle & handle)712 Status<FileReference> Endpoint::PushFileHandle(Message* /*message*/,
713 const RemoteHandle& handle) {
714 return handle.Get();
715 }
716
PushChannelHandle(Message * message,const LocalChannelHandle & handle)717 Status<ChannelReference> Endpoint::PushChannelHandle(
718 Message* message, const LocalChannelHandle& handle) {
719 auto* state = static_cast<MessageState*>(message->GetState());
720 return state->PushChannelHandle(handle.Borrow());
721 }
722
PushChannelHandle(Message * message,const BorrowedChannelHandle & handle)723 Status<ChannelReference> Endpoint::PushChannelHandle(
724 Message* message, const BorrowedChannelHandle& handle) {
725 auto* state = static_cast<MessageState*>(message->GetState());
726 return state->PushChannelHandle(handle.Duplicate());
727 }
728
PushChannelHandle(Message *,const RemoteChannelHandle & handle)729 Status<ChannelReference> Endpoint::PushChannelHandle(
730 Message* /*message*/, const RemoteChannelHandle& handle) {
731 return handle.value();
732 }
733
GetFileHandle(Message * message,FileReference ref) const734 LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const {
735 LocalHandle handle;
736 auto* state = static_cast<MessageState*>(message->GetState());
737 state->GetLocalFileHandle(ref, &handle);
738 return handle;
739 }
740
GetChannelHandle(Message * message,ChannelReference ref) const741 LocalChannelHandle Endpoint::GetChannelHandle(Message* message,
742 ChannelReference ref) const {
743 LocalChannelHandle handle;
744 auto* state = static_cast<MessageState*>(message->GetState());
745 state->GetLocalChannelHandle(ref, &handle);
746 return handle;
747 }
748
Cancel()749 Status<void> Endpoint::Cancel() {
750 if (eventfd_write(cancel_event_fd_.Get(), 1) < 0)
751 return ErrorStatus{errno};
752 return {};
753 }
754
Create(const std::string & endpoint_path,mode_t,bool blocking)755 std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path,
756 mode_t /*unused_mode*/,
757 bool blocking) {
758 return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking));
759 }
760
CreateAndBindSocket(const std::string & endpoint_path,bool blocking)761 std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
762 const std::string& endpoint_path, bool blocking) {
763 return std::unique_ptr<Endpoint>(
764 new Endpoint(endpoint_path, blocking, false));
765 }
766
CreateFromSocketFd(LocalHandle socket_fd)767 std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
768 return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
769 }
770
RegisterNewChannelForTests(LocalHandle channel_fd)771 Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
772 int optval = 1;
773 if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
774 sizeof(optval)) == -1) {
775 ALOGE(
776 "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
777 "of the credentials for channel %d: %s",
778 channel_fd.Get(), strerror(errno));
779 return ErrorStatus(errno);
780 }
781 return OnNewChannel(std::move(channel_fd));
782 }
783
784 } // namespace uds
785 } // namespace pdx
786 } // namespace android
787