#include "uds/ipc_helper.h"

#include <alloca.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <log/log.h>
#include <poll.h>
#include <string.h>
#include <sys/inotify.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <vector>

#include <pdx/service.h>
#include <pdx/utility.h>
16 
17 namespace android {
18 namespace pdx {
19 namespace uds {
20 
21 namespace {
22 
23 // Default implementations of Send/Receive interfaces to use standard socket
24 // send/sendmsg/recv/recvmsg functions.
25 class SocketSender : public SendInterface {
26  public:
Send(int socket_fd,const void * data,size_t size,int flags)27   ssize_t Send(int socket_fd, const void* data, size_t size,
28                int flags) override {
29     return send(socket_fd, data, size, flags);
30   }
SendMessage(int socket_fd,const msghdr * msg,int flags)31   ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) override {
32     return sendmsg(socket_fd, msg, flags);
33   }
34 } g_socket_sender;
35 
36 class SocketReceiver : public RecvInterface {
37  public:
Receive(int socket_fd,void * data,size_t size,int flags)38   ssize_t Receive(int socket_fd, void* data, size_t size, int flags) override {
39     return recv(socket_fd, data, size, flags);
40   }
ReceiveMessage(int socket_fd,msghdr * msg,int flags)41   ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) override {
42     return recvmsg(socket_fd, msg, flags);
43   }
44 } g_socket_receiver;
45 
46 }  // anonymous namespace
47 
48 // Helper wrappers around send()/sendmsg() which repeat send() calls on data
49 // that was not sent with the initial call to send/sendmsg. This is important to
50 // handle transmissions interrupted by signals.
SendAll(SendInterface * sender,const BorrowedHandle & socket_fd,const void * data,size_t size)51 Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
52                      const void* data, size_t size) {
53   Status<void> ret;
54   const uint8_t* ptr = static_cast<const uint8_t*>(data);
55   while (size > 0) {
56     ssize_t size_written =
57         RETRY_EINTR(sender->Send(socket_fd.Get(), ptr, size, MSG_NOSIGNAL));
58     if (size_written < 0) {
59       ret.SetError(errno);
60       ALOGE("SendAll: Failed to send data over socket: %s",
61             ret.GetErrorMessage().c_str());
62       break;
63     }
64     size -= size_written;
65     ptr += size_written;
66   }
67   return ret;
68 }
69 
SendMsgAll(SendInterface * sender,const BorrowedHandle & socket_fd,const msghdr * msg)70 Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
71                         const msghdr* msg) {
72   Status<void> ret;
73   ssize_t sent_size =
74       RETRY_EINTR(sender->SendMessage(socket_fd.Get(), msg, MSG_NOSIGNAL));
75   if (sent_size < 0) {
76     ret.SetError(errno);
77     ALOGE("SendMsgAll: Failed to send data over socket: %s",
78           ret.GetErrorMessage().c_str());
79     return ret;
80   }
81 
82   ssize_t chunk_start_offset = 0;
83   for (size_t i = 0; i < msg->msg_iovlen; i++) {
84     ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
85     if (sent_size < chunk_end_offset) {
86       size_t offset_within_chunk = sent_size - chunk_start_offset;
87       size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
88       const uint8_t* chunk_base =
89           static_cast<const uint8_t*>(msg->msg_iov[i].iov_base);
90       ret = SendAll(sender, socket_fd, chunk_base + offset_within_chunk,
91                     data_size);
92       if (!ret)
93         break;
94       sent_size += data_size;
95     }
96     chunk_start_offset = chunk_end_offset;
97   }
98   return ret;
99 }
100 
101 // Helper wrappers around recv()/recvmsg() which repeat recv() calls on data
102 // that was not received with the initial call to recvmsg(). This is important
103 // to handle transmissions interrupted by signals as well as the case when
104 // initial data did not arrive in a single chunk over the socket (e.g. socket
105 // buffer was full at the time of transmission, and only portion of initial
106 // message was sent and the rest was blocked until the buffer was cleared by the
107 // receiving side).
RecvMsgAll(RecvInterface * receiver,const BorrowedHandle & socket_fd,msghdr * msg)108 Status<void> RecvMsgAll(RecvInterface* receiver,
109                         const BorrowedHandle& socket_fd, msghdr* msg) {
110   Status<void> ret;
111   ssize_t size_read = RETRY_EINTR(receiver->ReceiveMessage(
112       socket_fd.Get(), msg, MSG_WAITALL | MSG_CMSG_CLOEXEC));
113   if (size_read < 0) {
114     ret.SetError(errno);
115     ALOGE("RecvMsgAll: Failed to receive data from socket: %s",
116           ret.GetErrorMessage().c_str());
117     return ret;
118   } else if (size_read == 0) {
119     ret.SetError(ESHUTDOWN);
120     ALOGW("RecvMsgAll: Socket has been shut down");
121     return ret;
122   }
123 
124   ssize_t chunk_start_offset = 0;
125   for (size_t i = 0; i < msg->msg_iovlen; i++) {
126     ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
127     if (size_read < chunk_end_offset) {
128       size_t offset_within_chunk = size_read - chunk_start_offset;
129       size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
130       uint8_t* chunk_base = static_cast<uint8_t*>(msg->msg_iov[i].iov_base);
131       ret = RecvAll(receiver, socket_fd, chunk_base + offset_within_chunk,
132                     data_size);
133       if (!ret)
134         break;
135       size_read += data_size;
136     }
137     chunk_start_offset = chunk_end_offset;
138   }
139   return ret;
140 }
141 
RecvAll(RecvInterface * receiver,const BorrowedHandle & socket_fd,void * data,size_t size)142 Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
143                      void* data, size_t size) {
144   Status<void> ret;
145   uint8_t* ptr = static_cast<uint8_t*>(data);
146   while (size > 0) {
147     ssize_t size_read = RETRY_EINTR(receiver->Receive(
148         socket_fd.Get(), ptr, size, MSG_WAITALL | MSG_CMSG_CLOEXEC));
149     if (size_read < 0) {
150       ret.SetError(errno);
151       ALOGE("RecvAll: Failed to receive data from socket: %s",
152             ret.GetErrorMessage().c_str());
153       break;
154     } else if (size_read == 0) {
155       ret.SetError(ESHUTDOWN);
156       ALOGW("RecvAll: Socket has been shut down");
157       break;
158     }
159     size -= size_read;
160     ptr += size_read;
161   }
162   return ret;
163 }
164 
// Value written into MessagePreamble::magic by the sender and checked by the
// receiver to recognize the start of a message. Its hex digits spell 'udsm'
// when read as big-endian ASCII. It is constexpr so the value cannot be
// modified at runtime and does not export a mutable global symbol.
constexpr uint32_t kMagicPreamble = 0x7564736d;  // 'udsm'.
166 
// Fixed-size header that precedes every payload on the socket. It is written
// and read as raw bytes (SendAll/RecvAll on &preamble), so its layout is part
// of the protocol:
//   magic     - must equal kMagicPreamble.
//   data_size - number of payload bytes that follow the preamble.
//   fd_count  - number of file descriptors sent alongside the payload as
//               SCM_RIGHTS ancillary data.
struct MessagePreamble {
  uint32_t magic{0};
  uint32_t data_size{0};
  uint32_t fd_count{0};
};

// Both peers exchange this struct byte-for-byte. Catch any accidental layout
// change (an extra field, unexpected padding) at compile time.
static_assert(sizeof(MessagePreamble) == 3 * sizeof(uint32_t),
              "MessagePreamble is a wire format; its size must not change");
172 
Send(const BorrowedHandle & socket_fd)173 Status<void> SendPayload::Send(const BorrowedHandle& socket_fd) {
174   return Send(socket_fd, nullptr);
175 }
176 
Send(const BorrowedHandle & socket_fd,const ucred * cred)177 Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
178                                const ucred* cred) {
179   SendInterface* sender = sender_ ? sender_ : &g_socket_sender;
180   MessagePreamble preamble;
181   preamble.magic = kMagicPreamble;
182   preamble.data_size = buffer_.size();
183   preamble.fd_count = file_handles_.size();
184   Status<void> ret = SendAll(sender, socket_fd, &preamble, sizeof(preamble));
185   if (!ret)
186     return ret;
187 
188   msghdr msg = {};
189   iovec recv_vect = {buffer_.data(), buffer_.size()};
190   msg.msg_iov = &recv_vect;
191   msg.msg_iovlen = 1;
192 
193   if (cred || !file_handles_.empty()) {
194     const size_t fd_bytes = file_handles_.size() * sizeof(int);
195     msg.msg_controllen = (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
196                          (fd_bytes == 0 ? 0 : CMSG_SPACE(fd_bytes));
197     msg.msg_control = alloca(msg.msg_controllen);
198 
199     cmsghdr* control = CMSG_FIRSTHDR(&msg);
200     if (cred) {
201       control->cmsg_level = SOL_SOCKET;
202       control->cmsg_type = SCM_CREDENTIALS;
203       control->cmsg_len = CMSG_LEN(sizeof(ucred));
204       memcpy(CMSG_DATA(control), cred, sizeof(ucred));
205       control = CMSG_NXTHDR(&msg, control);
206     }
207 
208     if (fd_bytes) {
209       control->cmsg_level = SOL_SOCKET;
210       control->cmsg_type = SCM_RIGHTS;
211       control->cmsg_len = CMSG_LEN(fd_bytes);
212       memcpy(CMSG_DATA(control), file_handles_.data(), fd_bytes);
213     }
214   }
215 
216   return SendMsgAll(sender, socket_fd, &msg);
217 }
218 
219 // MessageWriter
GetNextWriteBufferSection(size_t size)220 void* SendPayload::GetNextWriteBufferSection(size_t size) {
221   return buffer_.grow_by(size);
222 }
223 
GetOutputResourceMapper()224 OutputResourceMapper* SendPayload::GetOutputResourceMapper() { return this; }
225 
226 // OutputResourceMapper
PushFileHandle(const LocalHandle & handle)227 Status<FileReference> SendPayload::PushFileHandle(const LocalHandle& handle) {
228   if (handle) {
229     const int ref = file_handles_.size();
230     file_handles_.push_back(handle.Get());
231     return ref;
232   } else {
233     return handle.Get();
234   }
235 }
236 
PushFileHandle(const BorrowedHandle & handle)237 Status<FileReference> SendPayload::PushFileHandle(
238     const BorrowedHandle& handle) {
239   if (handle) {
240     const int ref = file_handles_.size();
241     file_handles_.push_back(handle.Get());
242     return ref;
243   } else {
244     return handle.Get();
245   }
246 }
247 
PushFileHandle(const RemoteHandle & handle)248 Status<FileReference> SendPayload::PushFileHandle(const RemoteHandle& handle) {
249   return handle.Get();
250 }
251 
PushChannelHandle(const LocalChannelHandle &)252 Status<ChannelReference> SendPayload::PushChannelHandle(
253     const LocalChannelHandle& /*handle*/) {
254   return ErrorStatus{EOPNOTSUPP};
255 }
PushChannelHandle(const BorrowedChannelHandle &)256 Status<ChannelReference> SendPayload::PushChannelHandle(
257     const BorrowedChannelHandle& /*handle*/) {
258   return ErrorStatus{EOPNOTSUPP};
259 }
PushChannelHandle(const RemoteChannelHandle &)260 Status<ChannelReference> SendPayload::PushChannelHandle(
261     const RemoteChannelHandle& /*handle*/) {
262   return ErrorStatus{EOPNOTSUPP};
263 }
264 
Receive(const BorrowedHandle & socket_fd)265 Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd) {
266   return Receive(socket_fd, nullptr);
267 }
268 
Receive(const BorrowedHandle & socket_fd,ucred * cred)269 Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
270                                      ucred* cred) {
271   RecvInterface* receiver = receiver_ ? receiver_ : &g_socket_receiver;
272   MessagePreamble preamble;
273   Status<void> ret = RecvAll(receiver, socket_fd, &preamble, sizeof(preamble));
274   if (!ret)
275     return ret;
276 
277   if (preamble.magic != kMagicPreamble) {
278     ALOGE("ReceivePayload::Receive: Message header is invalid");
279     ret.SetError(EIO);
280     return ret;
281   }
282 
283   buffer_.resize(preamble.data_size);
284   file_handles_.clear();
285   read_pos_ = 0;
286 
287   msghdr msg = {};
288   iovec recv_vect = {buffer_.data(), buffer_.size()};
289   msg.msg_iov = &recv_vect;
290   msg.msg_iovlen = 1;
291 
292   if (cred || preamble.fd_count) {
293     const size_t receive_fd_bytes = preamble.fd_count * sizeof(int);
294     msg.msg_controllen =
295         (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
296         (receive_fd_bytes == 0 ? 0 : CMSG_SPACE(receive_fd_bytes));
297     msg.msg_control = alloca(msg.msg_controllen);
298   }
299 
300   ret = RecvMsgAll(receiver, socket_fd, &msg);
301   if (!ret)
302     return ret;
303 
304   bool cred_available = false;
305   file_handles_.reserve(preamble.fd_count);
306   cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
307   while (cmsg) {
308     if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
309         cred && cmsg->cmsg_len == CMSG_LEN(sizeof(ucred))) {
310       cred_available = true;
311       memcpy(cred, CMSG_DATA(cmsg), sizeof(ucred));
312     } else if (cmsg->cmsg_level == SOL_SOCKET &&
313                cmsg->cmsg_type == SCM_RIGHTS) {
314       socklen_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
315       const int* fds = reinterpret_cast<const int*>(CMSG_DATA(cmsg));
316       size_t fd_count = payload_len / sizeof(int);
317       std::transform(fds, fds + fd_count, std::back_inserter(file_handles_),
318                      [](int fd) { return LocalHandle{fd}; });
319     }
320     cmsg = CMSG_NXTHDR(&msg, cmsg);
321   }
322 
323   if (cred && !cred_available) {
324     ALOGE("ReceivePayload::Receive: Failed to obtain message credentials");
325     ret.SetError(EIO);
326   }
327 
328   return ret;
329 }
330 
331 // MessageReader
GetNextReadBufferSection()332 MessageReader::BufferSection ReceivePayload::GetNextReadBufferSection() {
333   return {buffer_.data() + read_pos_, &*buffer_.end()};
334 }
335 
ConsumeReadBufferSectionData(const void * new_start)336 void ReceivePayload::ConsumeReadBufferSectionData(const void* new_start) {
337   read_pos_ = PointerDistance(new_start, buffer_.data());
338 }
339 
GetInputResourceMapper()340 InputResourceMapper* ReceivePayload::GetInputResourceMapper() { return this; }
341 
342 // InputResourceMapper
GetFileHandle(FileReference ref,LocalHandle * handle)343 bool ReceivePayload::GetFileHandle(FileReference ref, LocalHandle* handle) {
344   if (ref < 0) {
345     *handle = LocalHandle{ref};
346     return true;
347   }
348   if (static_cast<size_t>(ref) > file_handles_.size())
349     return false;
350   *handle = std::move(file_handles_[ref]);
351   return true;
352 }
353 
GetChannelHandle(ChannelReference,LocalChannelHandle *)354 bool ReceivePayload::GetChannelHandle(ChannelReference /*ref*/,
355                                       LocalChannelHandle* /*handle*/) {
356   return false;
357 }
358 
SendData(const BorrowedHandle & socket_fd,const void * data,size_t size)359 Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
360                       size_t size) {
361   return SendAll(&g_socket_sender, socket_fd, data, size);
362 }
363 
SendDataVector(const BorrowedHandle & socket_fd,const iovec * data,size_t count)364 Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
365                             size_t count) {
366   msghdr msg = {};
367   msg.msg_iov = const_cast<iovec*>(data);
368   msg.msg_iovlen = count;
369   return SendMsgAll(&g_socket_sender, socket_fd, &msg);
370 }
371 
ReceiveData(const BorrowedHandle & socket_fd,void * data,size_t size)372 Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
373                          size_t size) {
374   return RecvAll(&g_socket_receiver, socket_fd, data, size);
375 }
376 
ReceiveDataVector(const BorrowedHandle & socket_fd,const iovec * data,size_t count)377 Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
378                                const iovec* data, size_t count) {
379   msghdr msg = {};
380   msg.msg_iov = const_cast<iovec*>(data);
381   msg.msg_iovlen = count;
382   return RecvMsgAll(&g_socket_receiver, socket_fd, &msg);
383 }
384 
// Returns the total number of bytes covered by the |count| iovecs starting at
// |vector|, i.e. the sum of their iov_len fields (0 when |count| is 0).
// std::accumulate is declared in <numeric>, which must be included directly
// rather than relied on transitively through <algorithm>.
size_t CountVectorSize(const iovec* vector, size_t count) {
  return std::accumulate(
      vector, vector + count, size_t{0},
      [](size_t size, const iovec& vec) { return size + vec.iov_len; });
}
390 
InitRequest(android::pdx::uds::RequestHeader<BorrowedHandle> * request,int opcode,uint32_t send_len,uint32_t max_recv_len,bool is_impulse)391 void InitRequest(android::pdx::uds::RequestHeader<BorrowedHandle>* request,
392                  int opcode, uint32_t send_len, uint32_t max_recv_len,
393                  bool is_impulse) {
394   request->op = opcode;
395   request->cred.pid = getpid();
396   request->cred.uid = geteuid();
397   request->cred.gid = getegid();
398   request->send_len = send_len;
399   request->max_recv_len = max_recv_len;
400   request->is_impulse = is_impulse;
401 }
402 
WaitForEndpoint(const std::string & endpoint_path,int64_t timeout_ms)403 Status<void> WaitForEndpoint(const std::string& endpoint_path,
404                              int64_t timeout_ms) {
405   // Endpoint path must be absolute.
406   if (endpoint_path.empty() || endpoint_path.front() != '/')
407     return ErrorStatus(EINVAL);
408 
409   // Create inotify fd.
410   LocalHandle fd{inotify_init()};
411   if (!fd)
412     return ErrorStatus(errno);
413 
414   // Set the inotify fd to non-blocking.
415   int ret = fcntl(fd.Get(), F_GETFL);
416   fcntl(fd.Get(), F_SETFL, ret | O_NONBLOCK);
417 
418   // Setup the pollfd.
419   pollfd pfd = {fd.Get(), POLLIN, 0};
420 
421   // Find locations of each path separator.
422   std::vector<size_t> separators{0};  // The path is absolute, so '/' is at #0.
423   size_t pos = endpoint_path.find('/', 1);
424   while (pos != std::string::npos) {
425     separators.push_back(pos);
426     pos = endpoint_path.find('/', pos + 1);
427   }
428   separators.push_back(endpoint_path.size());
429 
430   // Walk down the path, checking for existence and waiting if needed.
431   pos = 1;
432   size_t links = 0;
433   std::string current;
434   while (pos < separators.size() && links <= MAXSYMLINKS) {
435     std::string previous = current;
436     current = endpoint_path.substr(0, separators[pos]);
437 
438     // Check for existence; proceed to setup a watch if not.
439     if (access(current.c_str(), F_OK) < 0) {
440       if (errno != ENOENT)
441         return ErrorStatus(errno);
442 
443       // Extract the name of the path component to wait for.
444       std::string next = current.substr(
445           separators[pos - 1] + 1, separators[pos] - separators[pos - 1] - 1);
446 
447       // Add a watch on the last existing directory we reach.
448       int wd = inotify_add_watch(
449           fd.Get(), previous.c_str(),
450           IN_CREATE | IN_DELETE_SELF | IN_MOVE_SELF | IN_MOVED_TO);
451       if (wd < 0) {
452         if (errno != ENOENT)
453           return ErrorStatus(errno);
454         // Restart at the beginning if previous was deleted.
455         links = 0;
456         current.clear();
457         pos = 1;
458         continue;
459       }
460 
461       // Make sure current didn't get created before the watch was added.
462       ret = access(current.c_str(), F_OK);
463       if (ret < 0) {
464         if (errno != ENOENT)
465           return ErrorStatus(errno);
466 
467         bool exit_poll = false;
468         while (!exit_poll) {
469           // Wait for an event or timeout.
470           ret = poll(&pfd, 1, timeout_ms);
471           if (ret <= 0)
472             return ErrorStatus(ret == 0 ? ETIMEDOUT : errno);
473 
474           // Read events.
475           char buffer[sizeof(inotify_event) + NAME_MAX + 1];
476 
477           ret = read(fd.Get(), buffer, sizeof(buffer));
478           if (ret < 0) {
479             if (errno == EAGAIN || errno == EWOULDBLOCK)
480               continue;
481             else
482               return ErrorStatus(errno);
483           } else if (static_cast<size_t>(ret) < sizeof(struct inotify_event)) {
484             return ErrorStatus(EIO);
485           }
486 
487           auto* event = reinterpret_cast<const inotify_event*>(buffer);
488           auto* end = reinterpret_cast<const inotify_event*>(buffer + ret);
489           while (event < end) {
490             std::string event_for;
491             if (event->len > 0)
492               event_for = event->name;
493 
494             if (event->mask & (IN_CREATE | IN_MOVED_TO)) {
495               // See if this is the droid we're looking for.
496               if (next == event_for) {
497                 exit_poll = true;
498                 break;
499               }
500             } else if (event->mask & (IN_DELETE_SELF | IN_MOVE_SELF)) {
501               // Restart at the beginning if our watch dir is deleted.
502               links = 0;
503               current.clear();
504               pos = 0;
505               exit_poll = true;
506               break;
507             }
508 
509             event = reinterpret_cast<const inotify_event*>(AdvancePointer(
510                 event, sizeof(struct inotify_event) + event->len));
511           }  // while (event < end)
512         }    // while (!exit_poll)
513       }      // Current dir doesn't exist.
514       ret = inotify_rm_watch(fd.Get(), wd);
515       if (ret < 0 && errno != EINVAL)
516         return ErrorStatus(errno);
517     }  // if (access(current.c_str(), F_OK) < 0)
518 
519     // Check for symbolic link and update link count.
520     struct stat stat_buf;
521     ret = lstat(current.c_str(), &stat_buf);
522     if (ret < 0 && errno != ENOENT)
523       return ErrorStatus(errno);
524     else if (ret == 0 && S_ISLNK(stat_buf.st_mode))
525       links++;
526     pos++;
527   }  // while (pos < separators.size() && links <= MAXSYMLINKS)
528 
529   return {};
530 }
531 
532 }  // namespace uds
533 }  // namespace pdx
534 }  // namespace android
535