1 #include "uds/service_endpoint.h"
2 
3 #include <poll.h>
4 #include <sys/epoll.h>
5 #include <sys/eventfd.h>
6 #include <sys/socket.h>
7 #include <sys/un.h>
8 #include <algorithm>  // std::min
9 
10 #include <android-base/logging.h>
11 #include <android-base/strings.h>
12 #include <cutils/sockets.h>
13 #include <pdx/service.h>
14 #include <selinux/selinux.h>
15 #include <uds/channel_manager.h>
16 #include <uds/client_channel_factory.h>
17 #include <uds/ipc_helper.h>
18 
19 namespace {
20 
21 constexpr int kMaxBackLogForSocketListen = 1;
22 
23 using android::pdx::BorrowedChannelHandle;
24 using android::pdx::BorrowedHandle;
25 using android::pdx::ChannelReference;
26 using android::pdx::ErrorStatus;
27 using android::pdx::FileReference;
28 using android::pdx::LocalChannelHandle;
29 using android::pdx::LocalHandle;
30 using android::pdx::Status;
31 using android::pdx::uds::ChannelInfo;
32 using android::pdx::uds::ChannelManager;
33 
34 struct MessageState {
GetLocalFileHandle__anon9a8a71d40111::MessageState35   bool GetLocalFileHandle(int index, LocalHandle* handle) {
36     if (index < 0) {
37       handle->Reset(index);
38     } else if (static_cast<size_t>(index) < request.file_descriptors.size()) {
39       *handle = std::move(request.file_descriptors[index]);
40     } else {
41       return false;
42     }
43     return true;
44   }
45 
GetLocalChannelHandle__anon9a8a71d40111::MessageState46   bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
47     if (index < 0) {
48       *handle = LocalChannelHandle{nullptr, index};
49     } else if (static_cast<size_t>(index) < request.channels.size()) {
50       auto& channel_info = request.channels[index];
51       *handle = ChannelManager::Get().CreateHandle(
52           std::move(channel_info.data_fd), std::move(channel_info.event_fd));
53     } else {
54       return false;
55     }
56     return true;
57   }
58 
PushFileHandle__anon9a8a71d40111::MessageState59   Status<FileReference> PushFileHandle(BorrowedHandle handle) {
60     if (!handle)
61       return handle.Get();
62     response.file_descriptors.push_back(std::move(handle));
63     return response.file_descriptors.size() - 1;
64   }
65 
PushChannelHandle__anon9a8a71d40111::MessageState66   Status<ChannelReference> PushChannelHandle(BorrowedChannelHandle handle) {
67     if (!handle)
68       return handle.value();
69 
70     if (auto* channel_data =
71             ChannelManager::Get().GetChannelData(handle.value())) {
72       ChannelInfo<BorrowedHandle> channel_info;
73       channel_info.data_fd.Reset(handle.value());
74       channel_info.event_fd = channel_data->event_receiver.event_fd();
75       response.channels.push_back(std::move(channel_info));
76       return response.channels.size() - 1;
77     } else {
78       return ErrorStatus{EINVAL};
79     }
80   }
81 
PushChannelHandle__anon9a8a71d40111::MessageState82   Status<ChannelReference> PushChannelHandle(BorrowedHandle data_fd,
83                                              BorrowedHandle event_fd) {
84     if (!data_fd || !event_fd)
85       return ErrorStatus{EINVAL};
86     ChannelInfo<BorrowedHandle> channel_info;
87     channel_info.data_fd = std::move(data_fd);
88     channel_info.event_fd = std::move(event_fd);
89     response.channels.push_back(std::move(channel_info));
90     return response.channels.size() - 1;
91   }
92 
WriteData__anon9a8a71d40111::MessageState93   Status<size_t> WriteData(const iovec* vector, size_t vector_length) {
94     size_t size = 0;
95     for (size_t i = 0; i < vector_length; i++) {
96       const auto* data = reinterpret_cast<const uint8_t*>(vector[i].iov_base);
97       response_data.insert(response_data.end(), data, data + vector[i].iov_len);
98       size += vector[i].iov_len;
99     }
100     return size;
101   }
102 
ReadData__anon9a8a71d40111::MessageState103   Status<size_t> ReadData(const iovec* vector, size_t vector_length) {
104     size_t size_remaining = request_data.size() - request_data_read_pos;
105     size_t size = 0;
106     for (size_t i = 0; i < vector_length && size_remaining > 0; i++) {
107       size_t size_to_copy = std::min(size_remaining, vector[i].iov_len);
108       memcpy(vector[i].iov_base, request_data.data() + request_data_read_pos,
109              size_to_copy);
110       size += size_to_copy;
111       request_data_read_pos += size_to_copy;
112       size_remaining -= size_to_copy;
113     }
114     return size;
115   }
116 
117   android::pdx::uds::RequestHeader<LocalHandle> request;
118   android::pdx::uds::ResponseHeader<BorrowedHandle> response;
119   std::vector<LocalHandle> sockets_to_close;
120   std::vector<uint8_t> request_data;
121   size_t request_data_read_pos{0};
122   std::vector<uint8_t> response_data;
123 };
124 
125 }  // anonymous namespace
126 
127 namespace android {
128 namespace pdx {
129 namespace uds {
130 
Endpoint(const std::string & endpoint_path,bool blocking,bool use_init_socket_fd)131 Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
132                    bool use_init_socket_fd)
133     : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)},
134       is_blocking_{blocking} {
135   LocalHandle fd;
136   if (use_init_socket_fd) {
137     // Cut off the /dev/socket/ prefix from the full socket path and use the
138     // resulting "name" to retrieve the file descriptor for the socket created
139     // by the init process.
140     constexpr char prefix[] = "/dev/socket/";
141     CHECK(android::base::StartsWith(endpoint_path_, prefix))
142         << "Endpoint::Endpoint: Socket name '" << endpoint_path_
143         << "' must begin with '" << prefix << "'";
144     std::string socket_name = endpoint_path_.substr(sizeof(prefix) - 1);
145     fd.Reset(android_get_control_socket(socket_name.c_str()));
146     CHECK(fd.IsValid())
147         << "Endpoint::Endpoint: Unable to obtain the control socket fd for '"
148         << socket_name << "'";
149     fcntl(fd.Get(), F_SETFD, FD_CLOEXEC);
150   } else {
151     fd.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0));
152     CHECK(fd.IsValid()) << "Endpoint::Endpoint: Failed to create socket: "
153                         << strerror(errno);
154 
155     sockaddr_un local;
156     local.sun_family = AF_UNIX;
157     strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path));
158     local.sun_path[sizeof(local.sun_path) - 1] = '\0';
159 
160     unlink(local.sun_path);
161     int ret =
162         bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
163     CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
164   }
165   Init(std::move(fd));
166 }
167 
Endpoint(LocalHandle socket_fd)168 Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }
169 
Init(LocalHandle socket_fd)170 void Endpoint::Init(LocalHandle socket_fd) {
171   if (socket_fd) {
172     CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
173         << "Endpoint::Endpoint: listen error: " << strerror(errno);
174   }
175   cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
176   CHECK(cancel_event_fd_.IsValid())
177       << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
178 
179   epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC));
180   CHECK(epoll_fd_.IsValid())
181       << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);
182 
183   if (socket_fd) {
184     epoll_event socket_event;
185     socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
186     socket_event.data.fd = socket_fd.Get();
187     int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
188                         &socket_event);
189     CHECK_EQ(ret, 0)
190         << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
191         << strerror(errno);
192   }
193 
194   epoll_event cancel_event;
195   cancel_event.events = EPOLLIN;
196   cancel_event.data.fd = cancel_event_fd_.Get();
197 
198   int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
199                       &cancel_event);
200   CHECK_EQ(ret, 0)
201       << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
202       << strerror(errno);
203   socket_fd_ = std::move(socket_fd);
204 }
205 
AllocateMessageState()206 void* Endpoint::AllocateMessageState() { return new MessageState; }
207 
FreeMessageState(void * state)208 void Endpoint::FreeMessageState(void* state) {
209   delete static_cast<MessageState*>(state);
210 }
211 
AcceptConnection(Message * message)212 Status<void> Endpoint::AcceptConnection(Message* message) {
213   if (!socket_fd_)
214     return ErrorStatus(EBADF);
215 
216   sockaddr_un remote;
217   socklen_t addrlen = sizeof(remote);
218   LocalHandle connection_fd{accept4(socket_fd_.Get(),
219                                     reinterpret_cast<sockaddr*>(&remote),
220                                     &addrlen, SOCK_CLOEXEC)};
221   if (!connection_fd) {
222     ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
223           strerror(errno));
224     return ErrorStatus(errno);
225   }
226 
227   LocalHandle local_socket;
228   LocalHandle remote_socket;
229   auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
230   if (!status)
231     return status;
232 
233   // Borrow the local channel handle before we move it into OnNewChannel().
234   BorrowedHandle channel_handle = local_socket.Borrow();
235   status = OnNewChannel(std::move(local_socket));
236   if (!status)
237     return status;
238 
239   // Send the channel socket fd to the client.
240   ChannelConnectionInfo<LocalHandle> connection_info;
241   connection_info.channel_fd = std::move(remote_socket);
242   status = SendData(connection_fd.Borrow(), connection_info);
243 
244   if (status) {
245     // Get the CHANNEL_OPEN message from client over the channel socket.
246     status = ReceiveMessageForChannel(channel_handle, message);
247   } else {
248     CloseChannel(GetChannelId(channel_handle));
249   }
250 
251   // Don't need the connection socket anymore. Further communication should
252   // happen over the channel socket.
253   shutdown(connection_fd.Get(), SHUT_WR);
254   return status;
255 }
256 
SetService(Service * service)257 Status<void> Endpoint::SetService(Service* service) {
258   service_ = service;
259   return {};
260 }
261 
SetChannel(int channel_id,Channel * channel)262 Status<void> Endpoint::SetChannel(int channel_id, Channel* channel) {
263   std::lock_guard<std::mutex> autolock(channel_mutex_);
264   auto channel_data = channels_.find(channel_id);
265   if (channel_data == channels_.end())
266     return ErrorStatus{EINVAL};
267   channel_data->second.channel_state = channel;
268   return {};
269 }
270 
OnNewChannel(LocalHandle channel_fd)271 Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) {
272   std::lock_guard<std::mutex> autolock(channel_mutex_);
273   Status<void> status;
274   status.PropagateError(OnNewChannelLocked(std::move(channel_fd), nullptr));
275   return status;
276 }
277 
OnNewChannelLocked(LocalHandle channel_fd,Channel * channel_state)278 Status<std::pair<int32_t, Endpoint::ChannelData*>> Endpoint::OnNewChannelLocked(
279     LocalHandle channel_fd, Channel* channel_state) {
280   epoll_event event;
281   event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
282   event.data.fd = channel_fd.Get();
283   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) {
284     ALOGE(
285         "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n",
286         strerror(errno));
287     return ErrorStatus(errno);
288   }
289   ChannelData channel_data;
290   channel_data.event_set.AddDataFd(channel_fd);
291   channel_data.data_fd = std::move(channel_fd);
292   channel_data.channel_state = channel_state;
293   for (;;) {
294     // Try new channel IDs until we find one which is not already in the map.
295     if (last_channel_id_++ == std::numeric_limits<int32_t>::max())
296       last_channel_id_ = 1;
297     auto iter = channels_.lower_bound(last_channel_id_);
298     if (iter == channels_.end() || iter->first != last_channel_id_) {
299       channel_fd_to_id_.emplace(channel_data.data_fd.Get(), last_channel_id_);
300       iter = channels_.emplace_hint(iter, last_channel_id_,
301                                     std::move(channel_data));
302       return std::make_pair(last_channel_id_, &iter->second);
303     }
304   }
305 }
306 
ReenableEpollEvent(const BorrowedHandle & fd)307 Status<void> Endpoint::ReenableEpollEvent(const BorrowedHandle& fd) {
308   epoll_event event;
309   event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
310   event.data.fd = fd.Get();
311   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd.Get(), &event) < 0) {
312     ALOGE(
313         "Endpoint::ReenableEpollEvent: Failed to re-enable channel to "
314         "endpoint: %s\n",
315         strerror(errno));
316     return ErrorStatus(errno);
317   }
318   return {};
319 }
320 
CloseChannel(int channel_id)321 Status<void> Endpoint::CloseChannel(int channel_id) {
322   std::lock_guard<std::mutex> autolock(channel_mutex_);
323   return CloseChannelLocked(channel_id);
324 }
325 
CloseChannelLocked(int32_t channel_id)326 Status<void> Endpoint::CloseChannelLocked(int32_t channel_id) {
327   ALOGD_IF(TRACE, "Endpoint::CloseChannelLocked: channel_id=%d", channel_id);
328 
329   auto iter = channels_.find(channel_id);
330   if (iter == channels_.end())
331     return ErrorStatus{EINVAL};
332 
333   int channel_fd = iter->second.data_fd.Get();
334   Status<void> status;
335   epoll_event dummy;  // See BUGS in man 2 epoll_ctl.
336   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_fd, &dummy) < 0) {
337     status.SetError(errno);
338     ALOGE(
339         "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: "
340         "%s\n",
341         strerror(errno));
342   } else {
343     status.SetValue();
344   }
345 
346   channel_fd_to_id_.erase(channel_fd);
347   channels_.erase(iter);
348   return status;
349 }
350 
ModifyChannelEvents(int channel_id,int clear_mask,int set_mask)351 Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask,
352                                            int set_mask) {
353   std::lock_guard<std::mutex> autolock(channel_mutex_);
354 
355   auto search = channels_.find(channel_id);
356   if (search != channels_.end()) {
357     auto& channel_data = search->second;
358     channel_data.event_set.ModifyEvents(clear_mask, set_mask);
359     return {};
360   }
361 
362   return ErrorStatus{EINVAL};
363 }
364 
CreateChannelSocketPair(LocalHandle * local_socket,LocalHandle * remote_socket)365 Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket,
366                                                LocalHandle* remote_socket) {
367   Status<void> status;
368   char* endpoint_context = nullptr;
369   // Make sure the channel socket has the correct SELinux label applied.
370   // Here we get the label from the endpoint file descriptor, which should be
371   // something like "u:object_r:pdx_service_endpoint_socket:s0" and replace
372   // "endpoint" with "channel" to produce the channel label such as this:
373   // "u:object_r:pdx_service_channel_socket:s0".
374   if (fgetfilecon_raw(socket_fd_.Get(), &endpoint_context) > 0) {
375     std::string channel_context = endpoint_context;
376     freecon(endpoint_context);
377     const std::string suffix = "_endpoint_socket";
378     auto pos = channel_context.find(suffix);
379     if (pos != std::string::npos) {
380       channel_context.replace(pos, suffix.size(), "_channel_socket");
381     } else {
382       ALOGW(
383           "Endpoint::CreateChannelSocketPair: Endpoint security context '%s' "
384           "does not contain expected substring '%s'",
385           channel_context.c_str(), suffix.c_str());
386     }
387     ALOGE_IF(setsockcreatecon_raw(channel_context.c_str()) == -1,
388              "Endpoint::CreateChannelSocketPair: Failed to set channel socket "
389              "security context: %s",
390              strerror(errno));
391   } else {
392     ALOGE(
393         "Endpoint::CreateChannelSocketPair: Failed to obtain the endpoint "
394         "socket's security context: %s",
395         strerror(errno));
396   }
397 
398   int channel_pair[2] = {};
399   if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
400     ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s",
401           strerror(errno));
402     status.SetError(errno);
403     return status;
404   }
405 
406   setsockcreatecon_raw(nullptr);
407 
408   local_socket->Reset(channel_pair[0]);
409   remote_socket->Reset(channel_pair[1]);
410 
411   int optval = 1;
412   if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval,
413                  sizeof(optval)) == -1) {
414     ALOGE(
415         "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of "
416         "the credentials for channel %d: %s",
417         local_socket->Get(), strerror(errno));
418     status.SetError(errno);
419   }
420   return status;
421 }
422 
PushChannel(Message * message,int,Channel * channel,int * channel_id)423 Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
424                                                   int /*flags*/,
425                                                   Channel* channel,
426                                                   int* channel_id) {
427   LocalHandle local_socket;
428   LocalHandle remote_socket;
429   auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
430   if (!status)
431     return status.error_status();
432 
433   std::lock_guard<std::mutex> autolock(channel_mutex_);
434   auto channel_data = OnNewChannelLocked(std::move(local_socket), channel);
435   if (!channel_data)
436     return channel_data.error_status();
437   *channel_id = channel_data.get().first;
438 
439   // Flags are ignored for now.
440   // TODO(xiaohuit): Implement those.
441 
442   auto* state = static_cast<MessageState*>(message->GetState());
443   Status<ChannelReference> ref = state->PushChannelHandle(
444       remote_socket.Borrow(),
445       channel_data.get().second->event_set.event_fd().Borrow());
446   if (!ref)
447     return ref.error_status();
448   state->sockets_to_close.push_back(std::move(remote_socket));
449   return RemoteChannelHandle{ref.get()};
450 }
451 
CheckChannel(const Message *,ChannelReference,Channel **)452 Status<int> Endpoint::CheckChannel(const Message* /*message*/,
453                                    ChannelReference /*ref*/,
454                                    Channel** /*channel*/) {
455   // TODO(xiaohuit): Implement this.
456   return ErrorStatus(EFAULT);
457 }
458 
GetChannelState(int32_t channel_id)459 Channel* Endpoint::GetChannelState(int32_t channel_id) {
460   std::lock_guard<std::mutex> autolock(channel_mutex_);
461   auto channel_data = channels_.find(channel_id);
462   return (channel_data != channels_.end()) ? channel_data->second.channel_state
463                                            : nullptr;
464 }
465 
GetChannelSocketFd(int32_t channel_id)466 BorrowedHandle Endpoint::GetChannelSocketFd(int32_t channel_id) {
467   std::lock_guard<std::mutex> autolock(channel_mutex_);
468   BorrowedHandle handle;
469   auto channel_data = channels_.find(channel_id);
470   if (channel_data != channels_.end())
471     handle = channel_data->second.data_fd.Borrow();
472   return handle;
473 }
474 
GetChannelEventFd(int32_t channel_id)475 BorrowedHandle Endpoint::GetChannelEventFd(int32_t channel_id) {
476   std::lock_guard<std::mutex> autolock(channel_mutex_);
477   BorrowedHandle handle;
478   auto channel_data = channels_.find(channel_id);
479   if (channel_data != channels_.end())
480     handle = channel_data->second.event_set.event_fd().Borrow();
481   return handle;
482 }
483 
GetChannelId(const BorrowedHandle & channel_fd)484 int32_t Endpoint::GetChannelId(const BorrowedHandle& channel_fd) {
485   std::lock_guard<std::mutex> autolock(channel_mutex_);
486   auto iter = channel_fd_to_id_.find(channel_fd.Get());
487   return (iter != channel_fd_to_id_.end()) ? iter->second : -1;
488 }
489 
ReceiveMessageForChannel(const BorrowedHandle & channel_fd,Message * message)490 Status<void> Endpoint::ReceiveMessageForChannel(
491     const BorrowedHandle& channel_fd, Message* message) {
492   RequestHeader<LocalHandle> request;
493   int32_t channel_id = GetChannelId(channel_fd);
494   auto status = ReceiveData(channel_fd.Borrow(), &request);
495   if (!status) {
496     if (status.error() == ESHUTDOWN) {
497       BuildCloseMessage(channel_id, message);
498       return {};
499     } else {
500       CloseChannel(channel_id);
501       return status;
502     }
503   }
504 
505   MessageInfo info;
506   info.pid = request.cred.pid;
507   info.tid = -1;
508   info.cid = channel_id;
509   info.mid = request.is_impulse ? Message::IMPULSE_MESSAGE_ID
510                                 : GetNextAvailableMessageId();
511   info.euid = request.cred.uid;
512   info.egid = request.cred.gid;
513   info.op = request.op;
514   info.flags = 0;
515   info.service = service_;
516   info.channel = GetChannelState(channel_id);
517   info.send_len = request.send_len;
518   info.recv_len = request.max_recv_len;
519   info.fd_count = request.file_descriptors.size();
520   static_assert(sizeof(info.impulse) == request.impulse_payload.size(),
521                 "Impulse payload sizes must be the same in RequestHeader and "
522                 "MessageInfo");
523   memcpy(info.impulse, request.impulse_payload.data(),
524          request.impulse_payload.size());
525   *message = Message{info};
526   auto* state = static_cast<MessageState*>(message->GetState());
527   state->request = std::move(request);
528   if (request.send_len > 0 && !request.is_impulse) {
529     state->request_data.resize(request.send_len);
530     status = ReceiveData(channel_fd, state->request_data.data(),
531                          state->request_data.size());
532   }
533 
534   if (status && request.is_impulse)
535     status = ReenableEpollEvent(channel_fd);
536 
537   if (!status) {
538     if (status.error() == ESHUTDOWN) {
539       BuildCloseMessage(channel_id, message);
540       return {};
541     } else {
542       CloseChannel(channel_id);
543       return status;
544     }
545   }
546 
547   return status;
548 }
549 
BuildCloseMessage(int32_t channel_id,Message * message)550 void Endpoint::BuildCloseMessage(int32_t channel_id, Message* message) {
551   ALOGD_IF(TRACE, "Endpoint::BuildCloseMessage: channel_id=%d", channel_id);
552   MessageInfo info;
553   info.pid = -1;
554   info.tid = -1;
555   info.cid = channel_id;
556   info.mid = GetNextAvailableMessageId();
557   info.euid = -1;
558   info.egid = -1;
559   info.op = opcodes::CHANNEL_CLOSE;
560   info.flags = 0;
561   info.service = service_;
562   info.channel = GetChannelState(channel_id);
563   info.send_len = 0;
564   info.recv_len = 0;
565   info.fd_count = 0;
566   *message = Message{info};
567 }
568 
MessageReceive(Message * message)569 Status<void> Endpoint::MessageReceive(Message* message) {
570   // Receive at most one event from the epoll set. This should prevent multiple
571   // dispatch threads from attempting to handle messages on the same socket at
572   // the same time.
573   epoll_event event;
574   int count = RETRY_EINTR(
575       epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0));
576   if (count < 0) {
577     ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n",
578           strerror(errno));
579     return ErrorStatus{errno};
580   } else if (count == 0) {
581     return ErrorStatus{ETIMEDOUT};
582   }
583 
584   if (event.data.fd == cancel_event_fd_.Get()) {
585     return ErrorStatus{ESHUTDOWN};
586   }
587 
588   if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
589     auto status = AcceptConnection(message);
590     if (!status)
591       return status;
592     return ReenableEpollEvent(socket_fd_.Borrow());
593   }
594 
595   BorrowedHandle channel_fd{event.data.fd};
596   if (event.events & (EPOLLRDHUP | EPOLLHUP)) {
597     BuildCloseMessage(GetChannelId(channel_fd), message);
598     return {};
599   }
600 
601   return ReceiveMessageForChannel(channel_fd, message);
602 }
603 
MessageReply(Message * message,int return_code)604 Status<void> Endpoint::MessageReply(Message* message, int return_code) {
605   const int32_t channel_id = message->GetChannelId();
606   auto channel_socket = GetChannelSocketFd(channel_id);
607   if (!channel_socket)
608     return ErrorStatus{EBADF};
609 
610   auto* state = static_cast<MessageState*>(message->GetState());
611   switch (message->GetOp()) {
612     case opcodes::CHANNEL_CLOSE:
613       return CloseChannel(channel_id);
614 
615     case opcodes::CHANNEL_OPEN:
616       if (return_code < 0) {
617         return CloseChannel(channel_id);
618       } else {
619         // Reply with the event fd.
620         auto push_status = state->PushFileHandle(GetChannelEventFd(channel_id));
621         state->response_data.clear();  // Just in case...
622         if (!push_status)
623           return push_status.error_status();
624         return_code = push_status.get();
625       }
626       break;
627   }
628 
629   state->response.ret_code = return_code;
630   state->response.recv_len = state->response_data.size();
631   auto status = SendData(channel_socket, state->response);
632   if (status && !state->response_data.empty()) {
633     status = SendData(channel_socket, state->response_data.data(),
634                       state->response_data.size());
635   }
636 
637   if (status)
638     status = ReenableEpollEvent(channel_socket);
639 
640   return status;
641 }
642 
MessageReplyFd(Message * message,unsigned int push_fd)643 Status<void> Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) {
644   auto* state = static_cast<MessageState*>(message->GetState());
645   auto ref = state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)});
646   if (!ref)
647     return ref.error_status();
648   return MessageReply(message, ref.get());
649 }
650 
MessageReplyChannelHandle(Message * message,const LocalChannelHandle & handle)651 Status<void> Endpoint::MessageReplyChannelHandle(
652     Message* message, const LocalChannelHandle& handle) {
653   auto* state = static_cast<MessageState*>(message->GetState());
654   auto ref = state->PushChannelHandle(handle.Borrow());
655   if (!ref)
656     return ref.error_status();
657   return MessageReply(message, ref.get());
658 }
659 
MessageReplyChannelHandle(Message * message,const BorrowedChannelHandle & handle)660 Status<void> Endpoint::MessageReplyChannelHandle(
661     Message* message, const BorrowedChannelHandle& handle) {
662   auto* state = static_cast<MessageState*>(message->GetState());
663   auto ref = state->PushChannelHandle(handle.Duplicate());
664   if (!ref)
665     return ref.error_status();
666   return MessageReply(message, ref.get());
667 }
668 
MessageReplyChannelHandle(Message * message,const RemoteChannelHandle & handle)669 Status<void> Endpoint::MessageReplyChannelHandle(
670     Message* message, const RemoteChannelHandle& handle) {
671   return MessageReply(message, handle.value());
672 }
673 
ReadMessageData(Message * message,const iovec * vector,size_t vector_length)674 Status<size_t> Endpoint::ReadMessageData(Message* message, const iovec* vector,
675                                          size_t vector_length) {
676   auto* state = static_cast<MessageState*>(message->GetState());
677   return state->ReadData(vector, vector_length);
678 }
679 
WriteMessageData(Message * message,const iovec * vector,size_t vector_length)680 Status<size_t> Endpoint::WriteMessageData(Message* message, const iovec* vector,
681                                           size_t vector_length) {
682   auto* state = static_cast<MessageState*>(message->GetState());
683   return state->WriteData(vector, vector_length);
684 }
685 
PushFileHandle(Message * message,const LocalHandle & handle)686 Status<FileReference> Endpoint::PushFileHandle(Message* message,
687                                                const LocalHandle& handle) {
688   auto* state = static_cast<MessageState*>(message->GetState());
689   return state->PushFileHandle(handle.Borrow());
690 }
691 
PushFileHandle(Message * message,const BorrowedHandle & handle)692 Status<FileReference> Endpoint::PushFileHandle(Message* message,
693                                                const BorrowedHandle& handle) {
694   auto* state = static_cast<MessageState*>(message->GetState());
695   return state->PushFileHandle(handle.Duplicate());
696 }
697 
PushFileHandle(Message *,const RemoteHandle & handle)698 Status<FileReference> Endpoint::PushFileHandle(Message* /*message*/,
699                                                const RemoteHandle& handle) {
700   return handle.Get();
701 }
702 
PushChannelHandle(Message * message,const LocalChannelHandle & handle)703 Status<ChannelReference> Endpoint::PushChannelHandle(
704     Message* message, const LocalChannelHandle& handle) {
705   auto* state = static_cast<MessageState*>(message->GetState());
706   return state->PushChannelHandle(handle.Borrow());
707 }
708 
PushChannelHandle(Message * message,const BorrowedChannelHandle & handle)709 Status<ChannelReference> Endpoint::PushChannelHandle(
710     Message* message, const BorrowedChannelHandle& handle) {
711   auto* state = static_cast<MessageState*>(message->GetState());
712   return state->PushChannelHandle(handle.Duplicate());
713 }
714 
PushChannelHandle(Message *,const RemoteChannelHandle & handle)715 Status<ChannelReference> Endpoint::PushChannelHandle(
716     Message* /*message*/, const RemoteChannelHandle& handle) {
717   return handle.value();
718 }
719 
GetFileHandle(Message * message,FileReference ref) const720 LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const {
721   LocalHandle handle;
722   auto* state = static_cast<MessageState*>(message->GetState());
723   state->GetLocalFileHandle(ref, &handle);
724   return handle;
725 }
726 
GetChannelHandle(Message * message,ChannelReference ref) const727 LocalChannelHandle Endpoint::GetChannelHandle(Message* message,
728                                               ChannelReference ref) const {
729   LocalChannelHandle handle;
730   auto* state = static_cast<MessageState*>(message->GetState());
731   state->GetLocalChannelHandle(ref, &handle);
732   return handle;
733 }
734 
Cancel()735 Status<void> Endpoint::Cancel() {
736   if (eventfd_write(cancel_event_fd_.Get(), 1) < 0)
737     return ErrorStatus{errno};
738   return {};
739 }
740 
Create(const std::string & endpoint_path,mode_t,bool blocking)741 std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path,
742                                            mode_t /*unused_mode*/,
743                                            bool blocking) {
744   return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking));
745 }
746 
CreateAndBindSocket(const std::string & endpoint_path,bool blocking)747 std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
748     const std::string& endpoint_path, bool blocking) {
749   return std::unique_ptr<Endpoint>(
750       new Endpoint(endpoint_path, blocking, false));
751 }
752 
CreateFromSocketFd(LocalHandle socket_fd)753 std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
754   return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
755 }
756 
RegisterNewChannelForTests(LocalHandle channel_fd)757 Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
758   int optval = 1;
759   if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
760                  sizeof(optval)) == -1) {
761     ALOGE(
762         "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
763         "of the credentials for channel %d: %s",
764         channel_fd.Get(), strerror(errno));
765     return ErrorStatus(errno);
766   }
767   return OnNewChannel(std::move(channel_fd));
768 }
769 
770 }  // namespace uds
771 }  // namespace pdx
772 }  // namespace android
773