1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre_host/socket_server.h"
18 
19 #include <poll.h>
20 
21 #include <cassert>
22 #include <cinttypes>
23 #include <csignal>
24 #include <cstdlib>
25 #include <map>
26 #include <mutex>
27 
28 #include <cutils/sockets.h>
29 
30 #include "chre_host/log.h"
31 
32 namespace android {
33 namespace chre {
34 
35 std::atomic<bool> SocketServer::sSignalReceived(false);
36 
37 namespace {
38 
maskAllSignals()39 void maskAllSignals() {
40   sigset_t signalMask;
41   sigfillset(&signalMask);
42   if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
43     LOG_ERROR("Couldn't mask all signals", errno);
44   }
45 }
46 
maskAllSignalsExceptIntAndTerm()47 void maskAllSignalsExceptIntAndTerm() {
48   sigset_t signalMask;
49   sigfillset(&signalMask);
50   sigdelset(&signalMask, SIGINT);
51   sigdelset(&signalMask, SIGTERM);
52   if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
53     LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
54   }
55 }
56 
57 }  // anonymous namespace
58 
SocketServer()59 SocketServer::SocketServer() {
60   // Initialize the socket fds field for all inactive client slots to -1, so
61   // poll skips over it, and we don't attempt to send on it
62   for (size_t i = 1; i <= kMaxActiveClients; i++) {
63     mPollFds[i].fd = -1;
64     mPollFds[i].events = POLLIN;
65   }
66 }
67 
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)68 void SocketServer::run(const char *socketName, bool allowSocketCreation,
69                        ClientMessageCallback clientMessageCallback) {
70   mClientMessageCallback = clientMessageCallback;
71 
72   mSockFd = android_get_control_socket(socketName);
73   if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
74     LOGI("Didn't inherit socket, creating...");
75     mSockFd = socket_local_server(socketName,
76                                   ANDROID_SOCKET_NAMESPACE_RESERVED,
77                                   SOCK_SEQPACKET);
78   }
79 
80   if (mSockFd == INVALID_SOCKET) {
81     LOGE("Couldn't get/create socket");
82   } else {
83     int ret = listen(mSockFd, kMaxPendingConnectionRequests);
84     if (ret < 0) {
85       LOG_ERROR("Couldn't listen on socket", errno);
86     } else {
87       serviceSocket();
88     }
89 
90     {
91       std::lock_guard<std::mutex> lock(mClientsMutex);
92       for (const auto& pair : mClients) {
93         int clientSocket = pair.first;
94         if (close(clientSocket) != 0) {
95           LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
96                pair.second.clientId, strerror(errno));
97         }
98       }
99       mClients.clear();
100     }
101     close(mSockFd);
102   }
103 }
104 
sendToAllClients(const void * data,size_t length)105 void SocketServer::sendToAllClients(const void *data, size_t length) {
106   std::lock_guard<std::mutex> lock(mClientsMutex);
107 
108   int deliveredCount = 0;
109   for (const auto& pair : mClients) {
110     int clientSocket = pair.first;
111     uint16_t clientId = pair.second.clientId;
112     if (sendToClientSocket(data, length, clientSocket, clientId)) {
113       deliveredCount++;
114     } else if (errno == EINTR) {
115       // Exit early if we were interrupted - we should only get this for
116       // SIGINT/SIGTERM, so we should exit quickly
117       break;
118     }
119   }
120 
121   if (deliveredCount == 0) {
122     LOGW("Got message but didn't deliver to any clients");
123   }
124 }
125 
sendToClientById(const void * data,size_t length,uint16_t clientId)126 bool SocketServer::sendToClientById(const void *data, size_t length,
127                                     uint16_t clientId) {
128   std::lock_guard<std::mutex> lock(mClientsMutex);
129 
130   bool sent = false;
131   for (const auto& pair : mClients) {
132     uint16_t thisClientId = pair.second.clientId;
133     if (thisClientId == clientId) {
134       int clientSocket = pair.first;
135       sent = sendToClientSocket(data, length, clientSocket, thisClientId);
136       break;
137     }
138   }
139 
140   return sent;
141 }
142 
acceptClientConnection()143 void SocketServer::acceptClientConnection() {
144   int clientSocket = accept(mSockFd, NULL, NULL);
145   if (clientSocket < 0) {
146     LOG_ERROR("Couldn't accept client connection", errno);
147   } else if (mClients.size() >= kMaxActiveClients) {
148     LOGW("Rejecting client request - maximum number of clients reached");
149     close(clientSocket);
150   } else {
151     ClientData clientData;
152     clientData.clientId = mNextClientId++;
153 
154     // We currently don't handle wraparound - if we're getting this many
155     // connects/disconnects, then something is wrong.
156     // TODO: can handle this properly by iterating over the existing clients to
157     // avoid a conflict.
158     if (clientData.clientId == 0) {
159       LOGE("Couldn't allocate client ID");
160       std::exit(-1);
161     }
162 
163     bool slotFound = false;
164     for (size_t i = 1; i <= kMaxActiveClients; i++) {
165       if (mPollFds[i].fd < 0) {
166         mPollFds[i].fd = clientSocket;
167         slotFound = true;
168         break;
169       }
170     }
171 
172     if (!slotFound) {
173       LOGE("Couldn't find slot for client!");
174       assert(slotFound);
175       close(clientSocket);
176     } else {
177       {
178         std::lock_guard<std::mutex> lock(mClientsMutex);
179         mClients[clientSocket] = clientData;
180       }
181       LOGI("Accepted new client connection (count %zu), assigned client ID %"
182            PRIu16, mClients.size(), clientData.clientId);
183     }
184   }
185 }
186 
handleClientData(int clientSocket)187 void SocketServer::handleClientData(int clientSocket) {
188   const ClientData& clientData = mClients[clientSocket];
189   uint16_t clientId = clientData.clientId;
190 
191   uint8_t buffer[kMaxPacketSize];
192   ssize_t packetSize = recv(clientSocket, buffer, sizeof(buffer), MSG_DONTWAIT);
193   if (packetSize < 0) {
194     LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
195          strerror(errno));
196   } else if (packetSize == 0) {
197     LOGI("Client %" PRIu16 " disconnected", clientId);
198     disconnectClient(clientSocket);
199   } else {
200     LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
201     mClientMessageCallback(clientId, buffer, packetSize);
202   }
203 }
204 
disconnectClient(int clientSocket)205 void SocketServer::disconnectClient(int clientSocket) {
206   {
207     std::lock_guard<std::mutex> lock(mClientsMutex);
208     mClients.erase(clientSocket);
209   }
210   close(clientSocket);
211 
212   bool removed = false;
213   for (size_t i = 1; i <= kMaxActiveClients; i++) {
214     if (mPollFds[i].fd == clientSocket) {
215       mPollFds[i].fd = -1;
216       removed = true;
217       break;
218     }
219   }
220 
221   if (!removed) {
222     LOGE("Out of sync");
223     assert(removed);
224   }
225 }
226 
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)227 bool SocketServer::sendToClientSocket(const void *data, size_t length,
228                                       int clientSocket, uint16_t clientId) {
229   errno = 0;
230   ssize_t bytesSent = send(clientSocket, data, length, 0);
231   if (bytesSent < 0) {
232     LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s",
233          length, clientId, strerror(errno));
234   } else if (bytesSent == 0) {
235     LOGW("Client %" PRIu16 " disconnected before message could be delivered",
236          clientId);
237   } else {
238     LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
239          clientId);
240   }
241 
242   return (bytesSent > 0);
243 }
244 
serviceSocket()245 void SocketServer::serviceSocket() {
246   constexpr size_t kListenIndex = 0;
247   static_assert(kListenIndex == 0, "Code assumes that the first index is "
248                 "always the listen socket");
249 
250   mPollFds[kListenIndex].fd = mSockFd;
251   mPollFds[kListenIndex].events = POLLIN;
252 
253   // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
254   // and ignore other signals
255   sigset_t signalMask;
256   sigfillset(&signalMask);
257   sigdelset(&signalMask, SIGINT);
258   sigdelset(&signalMask, SIGTERM);
259 
260   // Masking signals here ensure that after this point, we won't handle INT/TERM
261   // until after we call into ppoll()
262   maskAllSignals();
263   std::signal(SIGINT, signalHandler);
264   std::signal(SIGTERM, signalHandler);
265 
266   LOGI("Ready to accept connections");
267   while (!sSignalReceived) {
268     int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
269     maskAllSignalsExceptIntAndTerm();
270     if (ret == -1) {
271       LOGI("Exiting poll loop: %s", strerror(errno));
272       break;
273     }
274 
275     if (mPollFds[kListenIndex].revents & POLLIN) {
276       acceptClientConnection();
277     }
278 
279     for (size_t i = 1; i <= kMaxActiveClients; i++) {
280       if (mPollFds[i].fd < 0) {
281         continue;
282       }
283 
284       if (mPollFds[i].revents & POLLIN) {
285         handleClientData(mPollFds[i].fd);
286       }
287     }
288 
289     // Mask all signals to ensure that sSignalReceived can't become true between
290     // checking it in the while condition and calling into ppoll()
291     maskAllSignals();
292   }
293 }
294 
signalHandler(int signal)295 void SocketServer::signalHandler(int signal) {
296   LOGD("Caught signal %d", signal);
297   sSignalReceived = true;
298 }
299 
300 }  // namespace chre
301 }  // namespace android
302