1 #include "uds/service_endpoint.h"
2 
3 #include <poll.h>
4 #include <sys/epoll.h>
5 #include <sys/eventfd.h>
6 #include <sys/socket.h>
7 #include <sys/un.h>
8 #include <algorithm>  // std::min
9 
10 #include <android-base/logging.h>
11 #include <android-base/strings.h>
12 #include <cutils/sockets.h>
13 #include <pdx/service.h>
14 #include <selinux/selinux.h>
15 #include <uds/channel_manager.h>
16 #include <uds/client_channel_factory.h>
17 #include <uds/ipc_helper.h>
18 
19 namespace {
20 
21 constexpr int kMaxBackLogForSocketListen = 1;
22 
23 using android::pdx::BorrowedChannelHandle;
24 using android::pdx::BorrowedHandle;
25 using android::pdx::ChannelReference;
26 using android::pdx::ErrorStatus;
27 using android::pdx::FileReference;
28 using android::pdx::LocalChannelHandle;
29 using android::pdx::LocalHandle;
30 using android::pdx::Status;
31 using android::pdx::uds::ChannelInfo;
32 using android::pdx::uds::ChannelManager;
33 
34 struct MessageState {
GetLocalFileHandle__anon9a8a71d40111::MessageState35   bool GetLocalFileHandle(int index, LocalHandle* handle) {
36     if (index < 0) {
37       handle->Reset(index);
38     } else if (static_cast<size_t>(index) < request.file_descriptors.size()) {
39       *handle = std::move(request.file_descriptors[index]);
40     } else {
41       return false;
42     }
43     return true;
44   }
45 
GetLocalChannelHandle__anon9a8a71d40111::MessageState46   bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
47     if (index < 0) {
48       *handle = LocalChannelHandle{nullptr, index};
49     } else if (static_cast<size_t>(index) < request.channels.size()) {
50       auto& channel_info = request.channels[index];
51       *handle = ChannelManager::Get().CreateHandle(
52           std::move(channel_info.data_fd),
53           std::move(channel_info.pollin_event_fd),
54           std::move(channel_info.pollhup_event_fd));
55     } else {
56       return false;
57     }
58     return true;
59   }
60 
PushFileHandle__anon9a8a71d40111::MessageState61   Status<FileReference> PushFileHandle(BorrowedHandle handle) {
62     if (!handle)
63       return handle.Get();
64     response.file_descriptors.push_back(std::move(handle));
65     return response.file_descriptors.size() - 1;
66   }
67 
PushChannelHandle__anon9a8a71d40111::MessageState68   Status<ChannelReference> PushChannelHandle(BorrowedChannelHandle handle) {
69     if (!handle)
70       return handle.value();
71 
72     if (auto* channel_data =
73             ChannelManager::Get().GetChannelData(handle.value())) {
74       ChannelInfo<BorrowedHandle> channel_info{
75           channel_data->data_fd(), channel_data->pollin_event_fd(),
76           channel_data->pollhup_event_fd()};
77       response.channels.push_back(std::move(channel_info));
78       return response.channels.size() - 1;
79     } else {
80       return ErrorStatus{EINVAL};
81     }
82   }
83 
PushChannelHandle__anon9a8a71d40111::MessageState84   Status<ChannelReference> PushChannelHandle(BorrowedHandle data_fd,
85                                              BorrowedHandle pollin_event_fd,
86                                              BorrowedHandle pollhup_event_fd) {
87     if (!data_fd || !pollin_event_fd || !pollhup_event_fd)
88       return ErrorStatus{EINVAL};
89     ChannelInfo<BorrowedHandle> channel_info{std::move(data_fd),
90                                              std::move(pollin_event_fd),
91                                              std::move(pollhup_event_fd)};
92     response.channels.push_back(std::move(channel_info));
93     return response.channels.size() - 1;
94   }
95 
WriteData__anon9a8a71d40111::MessageState96   Status<size_t> WriteData(const iovec* vector, size_t vector_length) {
97     size_t size = 0;
98     for (size_t i = 0; i < vector_length; i++) {
99       const auto* data = reinterpret_cast<const uint8_t*>(vector[i].iov_base);
100       response_data.insert(response_data.end(), data, data + vector[i].iov_len);
101       size += vector[i].iov_len;
102     }
103     return size;
104   }
105 
ReadData__anon9a8a71d40111::MessageState106   Status<size_t> ReadData(const iovec* vector, size_t vector_length) {
107     size_t size_remaining = request_data.size() - request_data_read_pos;
108     size_t size = 0;
109     for (size_t i = 0; i < vector_length && size_remaining > 0; i++) {
110       size_t size_to_copy = std::min(size_remaining, vector[i].iov_len);
111       memcpy(vector[i].iov_base, request_data.data() + request_data_read_pos,
112              size_to_copy);
113       size += size_to_copy;
114       request_data_read_pos += size_to_copy;
115       size_remaining -= size_to_copy;
116     }
117     return size;
118   }
119 
120   android::pdx::uds::RequestHeader<LocalHandle> request;
121   android::pdx::uds::ResponseHeader<BorrowedHandle> response;
122   std::vector<LocalHandle> sockets_to_close;
123   std::vector<uint8_t> request_data;
124   size_t request_data_read_pos{0};
125   std::vector<uint8_t> response_data;
126 };
127 
128 }  // anonymous namespace
129 
130 namespace android {
131 namespace pdx {
132 namespace uds {
133 
Endpoint(const std::string & endpoint_path,bool blocking,bool use_init_socket_fd)134 Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
135                    bool use_init_socket_fd)
136     : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)},
137       is_blocking_{blocking} {
138   LocalHandle fd;
139   if (use_init_socket_fd) {
140     // Cut off the /dev/socket/ prefix from the full socket path and use the
141     // resulting "name" to retrieve the file descriptor for the socket created
142     // by the init process.
143     constexpr char prefix[] = "/dev/socket/";
144     CHECK(android::base::StartsWith(endpoint_path_, prefix))
145         << "Endpoint::Endpoint: Socket name '" << endpoint_path_
146         << "' must begin with '" << prefix << "'";
147     std::string socket_name = endpoint_path_.substr(sizeof(prefix) - 1);
148     fd.Reset(android_get_control_socket(socket_name.c_str()));
149     CHECK(fd.IsValid())
150         << "Endpoint::Endpoint: Unable to obtain the control socket fd for '"
151         << socket_name << "'";
152     fcntl(fd.Get(), F_SETFD, FD_CLOEXEC);
153   } else {
154     fd.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0));
155     CHECK(fd.IsValid()) << "Endpoint::Endpoint: Failed to create socket: "
156                         << strerror(errno);
157 
158     sockaddr_un local;
159     local.sun_family = AF_UNIX;
160     strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path));
161     local.sun_path[sizeof(local.sun_path) - 1] = '\0';
162 
163     unlink(local.sun_path);
164     int ret =
165         bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
166     CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
167   }
168   Init(std::move(fd));
169 }
170 
Endpoint(LocalHandle socket_fd)171 Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }
172 
Init(LocalHandle socket_fd)173 void Endpoint::Init(LocalHandle socket_fd) {
174   if (socket_fd) {
175     CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
176         << "Endpoint::Endpoint: listen error: " << strerror(errno);
177   }
178   cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
179   CHECK(cancel_event_fd_.IsValid())
180       << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
181 
182   epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC));
183   CHECK(epoll_fd_.IsValid())
184       << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);
185 
186   if (socket_fd) {
187     epoll_event socket_event;
188     socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
189     socket_event.data.fd = socket_fd.Get();
190     int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
191                         &socket_event);
192     CHECK_EQ(ret, 0)
193         << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
194         << strerror(errno);
195   }
196 
197   epoll_event cancel_event;
198   cancel_event.events = EPOLLIN;
199   cancel_event.data.fd = cancel_event_fd_.Get();
200 
201   int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
202                       &cancel_event);
203   CHECK_EQ(ret, 0)
204       << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
205       << strerror(errno);
206   socket_fd_ = std::move(socket_fd);
207 }
208 
AllocateMessageState()209 void* Endpoint::AllocateMessageState() { return new MessageState; }
210 
FreeMessageState(void * state)211 void Endpoint::FreeMessageState(void* state) {
212   delete static_cast<MessageState*>(state);
213 }
214 
AcceptConnection(Message * message)215 Status<void> Endpoint::AcceptConnection(Message* message) {
216   if (!socket_fd_)
217     return ErrorStatus(EBADF);
218 
219   sockaddr_un remote;
220   socklen_t addrlen = sizeof(remote);
221   LocalHandle connection_fd{accept4(socket_fd_.Get(),
222                                     reinterpret_cast<sockaddr*>(&remote),
223                                     &addrlen, SOCK_CLOEXEC)};
224   if (!connection_fd) {
225     ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
226           strerror(errno));
227     return ErrorStatus(errno);
228   }
229 
230   LocalHandle local_socket;
231   LocalHandle remote_socket;
232   auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
233   if (!status)
234     return status;
235 
236   // Borrow the local channel handle before we move it into OnNewChannel().
237   BorrowedHandle channel_handle = local_socket.Borrow();
238   status = OnNewChannel(std::move(local_socket));
239   if (!status)
240     return status;
241 
242   // Send the channel socket fd to the client.
243   ChannelConnectionInfo<LocalHandle> connection_info;
244   connection_info.channel_fd = std::move(remote_socket);
245   status = SendData(connection_fd.Borrow(), connection_info);
246 
247   if (status) {
248     // Get the CHANNEL_OPEN message from client over the channel socket.
249     status = ReceiveMessageForChannel(channel_handle, message);
250   } else {
251     CloseChannel(GetChannelId(channel_handle));
252   }
253 
254   // Don't need the connection socket anymore. Further communication should
255   // happen over the channel socket.
256   shutdown(connection_fd.Get(), SHUT_WR);
257   return status;
258 }
259 
SetService(Service * service)260 Status<void> Endpoint::SetService(Service* service) {
261   service_ = service;
262   return {};
263 }
264 
SetChannel(int channel_id,Channel * channel)265 Status<void> Endpoint::SetChannel(int channel_id, Channel* channel) {
266   std::lock_guard<std::mutex> autolock(channel_mutex_);
267   auto channel_data = channels_.find(channel_id);
268   if (channel_data == channels_.end())
269     return ErrorStatus{EINVAL};
270   channel_data->second.channel_state = channel;
271   return {};
272 }
273 
OnNewChannel(LocalHandle channel_fd)274 Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) {
275   std::lock_guard<std::mutex> autolock(channel_mutex_);
276   Status<void> status;
277   status.PropagateError(OnNewChannelLocked(std::move(channel_fd), nullptr));
278   return status;
279 }
280 
OnNewChannelLocked(LocalHandle channel_fd,Channel * channel_state)281 Status<std::pair<int32_t, Endpoint::ChannelData*>> Endpoint::OnNewChannelLocked(
282     LocalHandle channel_fd, Channel* channel_state) {
283   epoll_event event;
284   event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
285   event.data.fd = channel_fd.Get();
286   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) {
287     ALOGE(
288         "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n",
289         strerror(errno));
290     return ErrorStatus(errno);
291   }
292   ChannelData channel_data;
293   channel_data.data_fd = std::move(channel_fd);
294   channel_data.channel_state = channel_state;
295   for (;;) {
296     // Try new channel IDs until we find one which is not already in the map.
297     if (last_channel_id_++ == std::numeric_limits<int32_t>::max())
298       last_channel_id_ = 1;
299     auto iter = channels_.lower_bound(last_channel_id_);
300     if (iter == channels_.end() || iter->first != last_channel_id_) {
301       channel_fd_to_id_.emplace(channel_data.data_fd.Get(), last_channel_id_);
302       iter = channels_.emplace_hint(iter, last_channel_id_,
303                                     std::move(channel_data));
304       return std::make_pair(last_channel_id_, &iter->second);
305     }
306   }
307 }
308 
ReenableEpollEvent(const BorrowedHandle & fd)309 Status<void> Endpoint::ReenableEpollEvent(const BorrowedHandle& fd) {
310   epoll_event event;
311   event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
312   event.data.fd = fd.Get();
313   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd.Get(), &event) < 0) {
314     ALOGE(
315         "Endpoint::ReenableEpollEvent: Failed to re-enable channel to "
316         "endpoint: %s\n",
317         strerror(errno));
318     return ErrorStatus(errno);
319   }
320   return {};
321 }
322 
CloseChannel(int channel_id)323 Status<void> Endpoint::CloseChannel(int channel_id) {
324   std::lock_guard<std::mutex> autolock(channel_mutex_);
325   return CloseChannelLocked(channel_id);
326 }
327 
CloseChannelLocked(int32_t channel_id)328 Status<void> Endpoint::CloseChannelLocked(int32_t channel_id) {
329   ALOGD_IF(TRACE, "Endpoint::CloseChannelLocked: channel_id=%d", channel_id);
330 
331   auto iter = channels_.find(channel_id);
332   if (iter == channels_.end())
333     return ErrorStatus{EINVAL};
334 
335   int channel_fd = iter->second.data_fd.Get();
336   Status<void> status;
337   epoll_event dummy;  // See BUGS in man 2 epoll_ctl.
338   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_fd, &dummy) < 0) {
339     status.SetError(errno);
340     ALOGE(
341         "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: "
342         "%s\n",
343         strerror(errno));
344   } else {
345     status.SetValue();
346   }
347 
348   channel_fd_to_id_.erase(channel_fd);
349   channels_.erase(iter);
350   return status;
351 }
352 
ModifyChannelEvents(int channel_id,int clear_mask,int set_mask)353 Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask,
354                                            int set_mask) {
355   std::lock_guard<std::mutex> autolock(channel_mutex_);
356 
357   auto search = channels_.find(channel_id);
358   if (search != channels_.end()) {
359     auto& channel_data = search->second;
360     channel_data.event_set.ModifyEvents(clear_mask, set_mask);
361     return {};
362   }
363 
364   return ErrorStatus{EINVAL};
365 }
366 
CreateChannelSocketPair(LocalHandle * local_socket,LocalHandle * remote_socket)367 Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket,
368                                                LocalHandle* remote_socket) {
369   Status<void> status;
370   char* endpoint_context = nullptr;
371   // Make sure the channel socket has the correct SELinux label applied.
372   // Here we get the label from the endpoint file descriptor, which should be
373   // something like "u:object_r:pdx_service_endpoint_socket:s0" and replace
374   // "endpoint" with "channel" to produce the channel label such as this:
375   // "u:object_r:pdx_service_channel_socket:s0".
376   if (fgetfilecon_raw(socket_fd_.Get(), &endpoint_context) > 0) {
377     std::string channel_context = endpoint_context;
378     freecon(endpoint_context);
379     const std::string suffix = "_endpoint_socket";
380     auto pos = channel_context.find(suffix);
381     if (pos != std::string::npos) {
382       channel_context.replace(pos, suffix.size(), "_channel_socket");
383     } else {
384       ALOGW(
385           "Endpoint::CreateChannelSocketPair: Endpoint security context '%s' "
386           "does not contain expected substring '%s'",
387           channel_context.c_str(), suffix.c_str());
388     }
389     ALOGE_IF(setsockcreatecon_raw(channel_context.c_str()) == -1,
390              "Endpoint::CreateChannelSocketPair: Failed to set channel socket "
391              "security context: %s",
392              strerror(errno));
393   } else {
394     ALOGE(
395         "Endpoint::CreateChannelSocketPair: Failed to obtain the endpoint "
396         "socket's security context: %s",
397         strerror(errno));
398   }
399 
400   int channel_pair[2] = {};
401   if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
402     ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s",
403           strerror(errno));
404     status.SetError(errno);
405     return status;
406   }
407 
408   setsockcreatecon_raw(nullptr);
409 
410   local_socket->Reset(channel_pair[0]);
411   remote_socket->Reset(channel_pair[1]);
412 
413   int optval = 1;
414   if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval,
415                  sizeof(optval)) == -1) {
416     ALOGE(
417         "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of "
418         "the credentials for channel %d: %s",
419         local_socket->Get(), strerror(errno));
420     status.SetError(errno);
421   }
422   return status;
423 }
424 
PushChannel(Message * message,int,Channel * channel,int * channel_id)425 Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
426                                                   int /*flags*/,
427                                                   Channel* channel,
428                                                   int* channel_id) {
429   LocalHandle local_socket;
430   LocalHandle remote_socket;
431   auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
432   if (!status)
433     return status.error_status();
434 
435   std::lock_guard<std::mutex> autolock(channel_mutex_);
436   auto channel_data_status =
437       OnNewChannelLocked(std::move(local_socket), channel);
438   if (!channel_data_status)
439     return channel_data_status.error_status();
440 
441   ChannelData* channel_data;
442   std::tie(*channel_id, channel_data) = channel_data_status.take();
443 
444   // Flags are ignored for now.
445   // TODO(xiaohuit): Implement those.
446 
447   auto* state = static_cast<MessageState*>(message->GetState());
448   Status<ChannelReference> ref = state->PushChannelHandle(
449       remote_socket.Borrow(), channel_data->event_set.pollin_event_fd(),
450       channel_data->event_set.pollhup_event_fd());
451   if (!ref)
452     return ref.error_status();
453   state->sockets_to_close.push_back(std::move(remote_socket));
454   return RemoteChannelHandle{ref.get()};
455 }
456 
CheckChannel(const Message *,ChannelReference,Channel **)457 Status<int> Endpoint::CheckChannel(const Message* /*message*/,
458                                    ChannelReference /*ref*/,
459                                    Channel** /*channel*/) {
460   // TODO(xiaohuit): Implement this.
461   return ErrorStatus(EFAULT);
462 }
463 
GetChannelState(int32_t channel_id)464 Channel* Endpoint::GetChannelState(int32_t channel_id) {
465   std::lock_guard<std::mutex> autolock(channel_mutex_);
466   auto channel_data = channels_.find(channel_id);
467   return (channel_data != channels_.end()) ? channel_data->second.channel_state
468                                            : nullptr;
469 }
470 
GetChannelSocketFd(int32_t channel_id)471 BorrowedHandle Endpoint::GetChannelSocketFd(int32_t channel_id) {
472   std::lock_guard<std::mutex> autolock(channel_mutex_);
473   BorrowedHandle handle;
474   auto channel_data = channels_.find(channel_id);
475   if (channel_data != channels_.end())
476     handle = channel_data->second.data_fd.Borrow();
477   return handle;
478 }
479 
GetChannelEventFd(int32_t channel_id)480 Status<std::pair<BorrowedHandle, BorrowedHandle>> Endpoint::GetChannelEventFd(
481     int32_t channel_id) {
482   std::lock_guard<std::mutex> autolock(channel_mutex_);
483   auto channel_data = channels_.find(channel_id);
484   if (channel_data != channels_.end()) {
485     return {{channel_data->second.event_set.pollin_event_fd(),
486              channel_data->second.event_set.pollhup_event_fd()}};
487   }
488   return ErrorStatus(ENOENT);
489 }
490 
GetChannelId(const BorrowedHandle & channel_fd)491 int32_t Endpoint::GetChannelId(const BorrowedHandle& channel_fd) {
492   std::lock_guard<std::mutex> autolock(channel_mutex_);
493   auto iter = channel_fd_to_id_.find(channel_fd.Get());
494   return (iter != channel_fd_to_id_.end()) ? iter->second : -1;
495 }
496 
ReceiveMessageForChannel(const BorrowedHandle & channel_fd,Message * message)497 Status<void> Endpoint::ReceiveMessageForChannel(
498     const BorrowedHandle& channel_fd, Message* message) {
499   RequestHeader<LocalHandle> request;
500   int32_t channel_id = GetChannelId(channel_fd);
501   auto status = ReceiveData(channel_fd.Borrow(), &request);
502   if (!status) {
503     if (status.error() == ESHUTDOWN) {
504       BuildCloseMessage(channel_id, message);
505       return {};
506     } else {
507       CloseChannel(channel_id);
508       return status;
509     }
510   }
511 
512   MessageInfo info;
513   info.pid = request.cred.pid;
514   info.tid = -1;
515   info.cid = channel_id;
516   info.mid = request.is_impulse ? Message::IMPULSE_MESSAGE_ID
517                                 : GetNextAvailableMessageId();
518   info.euid = request.cred.uid;
519   info.egid = request.cred.gid;
520   info.op = request.op;
521   info.flags = 0;
522   info.service = service_;
523   info.channel = GetChannelState(channel_id);
524   info.send_len = request.send_len;
525   info.recv_len = request.max_recv_len;
526   info.fd_count = request.file_descriptors.size();
527   static_assert(sizeof(info.impulse) == request.impulse_payload.size(),
528                 "Impulse payload sizes must be the same in RequestHeader and "
529                 "MessageInfo");
530   memcpy(info.impulse, request.impulse_payload.data(),
531          request.impulse_payload.size());
532   *message = Message{info};
533   auto* state = static_cast<MessageState*>(message->GetState());
534   state->request = std::move(request);
535   if (request.send_len > 0 && !request.is_impulse) {
536     state->request_data.resize(request.send_len);
537     status = ReceiveData(channel_fd, state->request_data.data(),
538                          state->request_data.size());
539   }
540 
541   if (status && request.is_impulse)
542     status = ReenableEpollEvent(channel_fd);
543 
544   if (!status) {
545     if (status.error() == ESHUTDOWN) {
546       BuildCloseMessage(channel_id, message);
547       return {};
548     } else {
549       CloseChannel(channel_id);
550       return status;
551     }
552   }
553 
554   return status;
555 }
556 
BuildCloseMessage(int32_t channel_id,Message * message)557 void Endpoint::BuildCloseMessage(int32_t channel_id, Message* message) {
558   ALOGD_IF(TRACE, "Endpoint::BuildCloseMessage: channel_id=%d", channel_id);
559   MessageInfo info;
560   info.pid = -1;
561   info.tid = -1;
562   info.cid = channel_id;
563   info.mid = GetNextAvailableMessageId();
564   info.euid = -1;
565   info.egid = -1;
566   info.op = opcodes::CHANNEL_CLOSE;
567   info.flags = 0;
568   info.service = service_;
569   info.channel = GetChannelState(channel_id);
570   info.send_len = 0;
571   info.recv_len = 0;
572   info.fd_count = 0;
573   *message = Message{info};
574 }
575 
MessageReceive(Message * message)576 Status<void> Endpoint::MessageReceive(Message* message) {
577   // Receive at most one event from the epoll set. This should prevent multiple
578   // dispatch threads from attempting to handle messages on the same socket at
579   // the same time.
580   epoll_event event;
581   int count = RETRY_EINTR(
582       epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0));
583   if (count < 0) {
584     ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n",
585           strerror(errno));
586     return ErrorStatus{errno};
587   } else if (count == 0) {
588     return ErrorStatus{ETIMEDOUT};
589   }
590 
591   if (event.data.fd == cancel_event_fd_.Get()) {
592     return ErrorStatus{ESHUTDOWN};
593   }
594 
595   if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
596     auto status = AcceptConnection(message);
597     auto reenable_status = ReenableEpollEvent(socket_fd_.Borrow());
598     if (!reenable_status)
599       return reenable_status;
600     return status;
601   }
602 
603   BorrowedHandle channel_fd{event.data.fd};
604   return ReceiveMessageForChannel(channel_fd, message);
605 }
606 
MessageReply(Message * message,int return_code)607 Status<void> Endpoint::MessageReply(Message* message, int return_code) {
608   const int32_t channel_id = message->GetChannelId();
609   auto channel_socket = GetChannelSocketFd(channel_id);
610   if (!channel_socket)
611     return ErrorStatus{EBADF};
612 
613   auto* state = static_cast<MessageState*>(message->GetState());
614   switch (message->GetOp()) {
615     case opcodes::CHANNEL_CLOSE:
616       return CloseChannel(channel_id);
617 
618     case opcodes::CHANNEL_OPEN:
619       if (return_code < 0) {
620         return CloseChannel(channel_id);
621       } else {
622         // Open messages do not have a payload and may not transfer any channels
623         // or file descriptors on behalf of the service.
624         state->response_data.clear();
625         state->response.file_descriptors.clear();
626         state->response.channels.clear();
627 
628         // Return the channel event-related fds in a single ChannelInfo entry
629         // with an empty data_fd member.
630         auto status = GetChannelEventFd(channel_id);
631         if (!status)
632           return status.error_status();
633 
634         auto handles = status.take();
635         state->response.channels.push_back({BorrowedHandle(),
636                                             std::move(handles.first),
637                                             std::move(handles.second)});
638         return_code = 0;
639       }
640       break;
641   }
642 
643   state->response.ret_code = return_code;
644   state->response.recv_len = state->response_data.size();
645   auto status = SendData(channel_socket, state->response);
646   if (status && !state->response_data.empty()) {
647     status = SendData(channel_socket, state->response_data.data(),
648                       state->response_data.size());
649   }
650 
651   if (status)
652     status = ReenableEpollEvent(channel_socket);
653 
654   return status;
655 }
656 
MessageReplyFd(Message * message,unsigned int push_fd)657 Status<void> Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) {
658   auto* state = static_cast<MessageState*>(message->GetState());
659   auto ref = state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)});
660   if (!ref)
661     return ref.error_status();
662   return MessageReply(message, ref.get());
663 }
664 
MessageReplyChannelHandle(Message * message,const LocalChannelHandle & handle)665 Status<void> Endpoint::MessageReplyChannelHandle(
666     Message* message, const LocalChannelHandle& handle) {
667   auto* state = static_cast<MessageState*>(message->GetState());
668   auto ref = state->PushChannelHandle(handle.Borrow());
669   if (!ref)
670     return ref.error_status();
671   return MessageReply(message, ref.get());
672 }
673 
MessageReplyChannelHandle(Message * message,const BorrowedChannelHandle & handle)674 Status<void> Endpoint::MessageReplyChannelHandle(
675     Message* message, const BorrowedChannelHandle& handle) {
676   auto* state = static_cast<MessageState*>(message->GetState());
677   auto ref = state->PushChannelHandle(handle.Duplicate());
678   if (!ref)
679     return ref.error_status();
680   return MessageReply(message, ref.get());
681 }
682 
MessageReplyChannelHandle(Message * message,const RemoteChannelHandle & handle)683 Status<void> Endpoint::MessageReplyChannelHandle(
684     Message* message, const RemoteChannelHandle& handle) {
685   return MessageReply(message, handle.value());
686 }
687 
ReadMessageData(Message * message,const iovec * vector,size_t vector_length)688 Status<size_t> Endpoint::ReadMessageData(Message* message, const iovec* vector,
689                                          size_t vector_length) {
690   auto* state = static_cast<MessageState*>(message->GetState());
691   return state->ReadData(vector, vector_length);
692 }
693 
WriteMessageData(Message * message,const iovec * vector,size_t vector_length)694 Status<size_t> Endpoint::WriteMessageData(Message* message, const iovec* vector,
695                                           size_t vector_length) {
696   auto* state = static_cast<MessageState*>(message->GetState());
697   return state->WriteData(vector, vector_length);
698 }
699 
PushFileHandle(Message * message,const LocalHandle & handle)700 Status<FileReference> Endpoint::PushFileHandle(Message* message,
701                                                const LocalHandle& handle) {
702   auto* state = static_cast<MessageState*>(message->GetState());
703   return state->PushFileHandle(handle.Borrow());
704 }
705 
PushFileHandle(Message * message,const BorrowedHandle & handle)706 Status<FileReference> Endpoint::PushFileHandle(Message* message,
707                                                const BorrowedHandle& handle) {
708   auto* state = static_cast<MessageState*>(message->GetState());
709   return state->PushFileHandle(handle.Duplicate());
710 }
711 
PushFileHandle(Message *,const RemoteHandle & handle)712 Status<FileReference> Endpoint::PushFileHandle(Message* /*message*/,
713                                                const RemoteHandle& handle) {
714   return handle.Get();
715 }
716 
PushChannelHandle(Message * message,const LocalChannelHandle & handle)717 Status<ChannelReference> Endpoint::PushChannelHandle(
718     Message* message, const LocalChannelHandle& handle) {
719   auto* state = static_cast<MessageState*>(message->GetState());
720   return state->PushChannelHandle(handle.Borrow());
721 }
722 
PushChannelHandle(Message * message,const BorrowedChannelHandle & handle)723 Status<ChannelReference> Endpoint::PushChannelHandle(
724     Message* message, const BorrowedChannelHandle& handle) {
725   auto* state = static_cast<MessageState*>(message->GetState());
726   return state->PushChannelHandle(handle.Duplicate());
727 }
728 
PushChannelHandle(Message *,const RemoteChannelHandle & handle)729 Status<ChannelReference> Endpoint::PushChannelHandle(
730     Message* /*message*/, const RemoteChannelHandle& handle) {
731   return handle.value();
732 }
733 
GetFileHandle(Message * message,FileReference ref) const734 LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const {
735   LocalHandle handle;
736   auto* state = static_cast<MessageState*>(message->GetState());
737   state->GetLocalFileHandle(ref, &handle);
738   return handle;
739 }
740 
GetChannelHandle(Message * message,ChannelReference ref) const741 LocalChannelHandle Endpoint::GetChannelHandle(Message* message,
742                                               ChannelReference ref) const {
743   LocalChannelHandle handle;
744   auto* state = static_cast<MessageState*>(message->GetState());
745   state->GetLocalChannelHandle(ref, &handle);
746   return handle;
747 }
748 
Cancel()749 Status<void> Endpoint::Cancel() {
750   if (eventfd_write(cancel_event_fd_.Get(), 1) < 0)
751     return ErrorStatus{errno};
752   return {};
753 }
754 
Create(const std::string & endpoint_path,mode_t,bool blocking)755 std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path,
756                                            mode_t /*unused_mode*/,
757                                            bool blocking) {
758   return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking));
759 }
760 
CreateAndBindSocket(const std::string & endpoint_path,bool blocking)761 std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
762     const std::string& endpoint_path, bool blocking) {
763   return std::unique_ptr<Endpoint>(
764       new Endpoint(endpoint_path, blocking, false));
765 }
766 
CreateFromSocketFd(LocalHandle socket_fd)767 std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
768   return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
769 }
770 
RegisterNewChannelForTests(LocalHandle channel_fd)771 Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
772   int optval = 1;
773   if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
774                  sizeof(optval)) == -1) {
775     ALOGE(
776         "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
777         "of the credentials for channel %d: %s",
778         channel_fd.Get(), strerror(errno));
779     return ErrorStatus(errno);
780   }
781   return OnNewChannel(std::move(channel_fd));
782 }
783 
784 }  // namespace uds
785 }  // namespace pdx
786 }  // namespace android
787