1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "chre_host/socket_server.h"
18
19 #include <poll.h>
20
21 #include <cassert>
22 #include <cinttypes>
23 #include <csignal>
24 #include <cstdlib>
25 #include <map>
26 #include <mutex>
27
28 #include <cutils/sockets.h>
29
30 #include "chre_host/log.h"
31
32 namespace android {
33 namespace chre {
34
35 std::atomic<bool> SocketServer::sSignalReceived(false);
36
37 namespace {
38
maskAllSignals()39 void maskAllSignals() {
40 sigset_t signalMask;
41 sigfillset(&signalMask);
42 if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
43 LOG_ERROR("Couldn't mask all signals", errno);
44 }
45 }
46
maskAllSignalsExceptIntAndTerm()47 void maskAllSignalsExceptIntAndTerm() {
48 sigset_t signalMask;
49 sigfillset(&signalMask);
50 sigdelset(&signalMask, SIGINT);
51 sigdelset(&signalMask, SIGTERM);
52 if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
53 LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
54 }
55 }
56
57 } // anonymous namespace
58
SocketServer()59 SocketServer::SocketServer() {
60 // Initialize the socket fds field for all inactive client slots to -1, so
61 // poll skips over it, and we don't attempt to send on it
62 for (size_t i = 1; i <= kMaxActiveClients; i++) {
63 mPollFds[i].fd = -1;
64 mPollFds[i].events = POLLIN;
65 }
66 }
67
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)68 void SocketServer::run(const char *socketName, bool allowSocketCreation,
69 ClientMessageCallback clientMessageCallback) {
70 mClientMessageCallback = clientMessageCallback;
71
72 mSockFd = android_get_control_socket(socketName);
73 if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
74 LOGI("Didn't inherit socket, creating...");
75 mSockFd = socket_local_server(socketName,
76 ANDROID_SOCKET_NAMESPACE_RESERVED,
77 SOCK_SEQPACKET);
78 }
79
80 if (mSockFd == INVALID_SOCKET) {
81 LOGE("Couldn't get/create socket");
82 } else {
83 int ret = listen(mSockFd, kMaxPendingConnectionRequests);
84 if (ret < 0) {
85 LOG_ERROR("Couldn't listen on socket", errno);
86 } else {
87 serviceSocket();
88 }
89
90 {
91 std::lock_guard<std::mutex> lock(mClientsMutex);
92 for (const auto& pair : mClients) {
93 int clientSocket = pair.first;
94 if (close(clientSocket) != 0) {
95 LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
96 pair.second.clientId, strerror(errno));
97 }
98 }
99 mClients.clear();
100 }
101 close(mSockFd);
102 }
103 }
104
sendToAllClients(const void * data,size_t length)105 void SocketServer::sendToAllClients(const void *data, size_t length) {
106 std::lock_guard<std::mutex> lock(mClientsMutex);
107
108 int deliveredCount = 0;
109 for (const auto& pair : mClients) {
110 int clientSocket = pair.first;
111 uint16_t clientId = pair.second.clientId;
112 if (sendToClientSocket(data, length, clientSocket, clientId)) {
113 deliveredCount++;
114 } else if (errno == EINTR) {
115 // Exit early if we were interrupted - we should only get this for
116 // SIGINT/SIGTERM, so we should exit quickly
117 break;
118 }
119 }
120
121 if (deliveredCount == 0) {
122 LOGW("Got message but didn't deliver to any clients");
123 }
124 }
125
sendToClientById(const void * data,size_t length,uint16_t clientId)126 bool SocketServer::sendToClientById(const void *data, size_t length,
127 uint16_t clientId) {
128 std::lock_guard<std::mutex> lock(mClientsMutex);
129
130 bool sent = false;
131 for (const auto& pair : mClients) {
132 uint16_t thisClientId = pair.second.clientId;
133 if (thisClientId == clientId) {
134 int clientSocket = pair.first;
135 sent = sendToClientSocket(data, length, clientSocket, thisClientId);
136 break;
137 }
138 }
139
140 return sent;
141 }
142
acceptClientConnection()143 void SocketServer::acceptClientConnection() {
144 int clientSocket = accept(mSockFd, NULL, NULL);
145 if (clientSocket < 0) {
146 LOG_ERROR("Couldn't accept client connection", errno);
147 } else if (mClients.size() >= kMaxActiveClients) {
148 LOGW("Rejecting client request - maximum number of clients reached");
149 close(clientSocket);
150 } else {
151 ClientData clientData;
152 clientData.clientId = mNextClientId++;
153
154 // We currently don't handle wraparound - if we're getting this many
155 // connects/disconnects, then something is wrong.
156 // TODO: can handle this properly by iterating over the existing clients to
157 // avoid a conflict.
158 if (clientData.clientId == 0) {
159 LOGE("Couldn't allocate client ID");
160 std::exit(-1);
161 }
162
163 bool slotFound = false;
164 for (size_t i = 1; i <= kMaxActiveClients; i++) {
165 if (mPollFds[i].fd < 0) {
166 mPollFds[i].fd = clientSocket;
167 slotFound = true;
168 break;
169 }
170 }
171
172 if (!slotFound) {
173 LOGE("Couldn't find slot for client!");
174 assert(slotFound);
175 close(clientSocket);
176 } else {
177 {
178 std::lock_guard<std::mutex> lock(mClientsMutex);
179 mClients[clientSocket] = clientData;
180 }
181 LOGI("Accepted new client connection (count %zu), assigned client ID %"
182 PRIu16, mClients.size(), clientData.clientId);
183 }
184 }
185 }
186
handleClientData(int clientSocket)187 void SocketServer::handleClientData(int clientSocket) {
188 const ClientData& clientData = mClients[clientSocket];
189 uint16_t clientId = clientData.clientId;
190
191 uint8_t buffer[kMaxPacketSize];
192 ssize_t packetSize = recv(clientSocket, buffer, sizeof(buffer), MSG_DONTWAIT);
193 if (packetSize < 0) {
194 LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
195 strerror(errno));
196 } else if (packetSize == 0) {
197 LOGI("Client %" PRIu16 " disconnected", clientId);
198 disconnectClient(clientSocket);
199 } else {
200 LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
201 mClientMessageCallback(clientId, buffer, packetSize);
202 }
203 }
204
disconnectClient(int clientSocket)205 void SocketServer::disconnectClient(int clientSocket) {
206 {
207 std::lock_guard<std::mutex> lock(mClientsMutex);
208 mClients.erase(clientSocket);
209 }
210 close(clientSocket);
211
212 bool removed = false;
213 for (size_t i = 1; i <= kMaxActiveClients; i++) {
214 if (mPollFds[i].fd == clientSocket) {
215 mPollFds[i].fd = -1;
216 removed = true;
217 break;
218 }
219 }
220
221 if (!removed) {
222 LOGE("Out of sync");
223 assert(removed);
224 }
225 }
226
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)227 bool SocketServer::sendToClientSocket(const void *data, size_t length,
228 int clientSocket, uint16_t clientId) {
229 errno = 0;
230 ssize_t bytesSent = send(clientSocket, data, length, 0);
231 if (bytesSent < 0) {
232 LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s",
233 length, clientId, strerror(errno));
234 } else if (bytesSent == 0) {
235 LOGW("Client %" PRIu16 " disconnected before message could be delivered",
236 clientId);
237 } else {
238 LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
239 clientId);
240 }
241
242 return (bytesSent > 0);
243 }
244
serviceSocket()245 void SocketServer::serviceSocket() {
246 constexpr size_t kListenIndex = 0;
247 static_assert(kListenIndex == 0, "Code assumes that the first index is "
248 "always the listen socket");
249
250 mPollFds[kListenIndex].fd = mSockFd;
251 mPollFds[kListenIndex].events = POLLIN;
252
253 // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
254 // and ignore other signals
255 sigset_t signalMask;
256 sigfillset(&signalMask);
257 sigdelset(&signalMask, SIGINT);
258 sigdelset(&signalMask, SIGTERM);
259
260 // Masking signals here ensure that after this point, we won't handle INT/TERM
261 // until after we call into ppoll()
262 maskAllSignals();
263 std::signal(SIGINT, signalHandler);
264 std::signal(SIGTERM, signalHandler);
265
266 LOGI("Ready to accept connections");
267 while (!sSignalReceived) {
268 int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
269 maskAllSignalsExceptIntAndTerm();
270 if (ret == -1) {
271 LOGI("Exiting poll loop: %s", strerror(errno));
272 break;
273 }
274
275 if (mPollFds[kListenIndex].revents & POLLIN) {
276 acceptClientConnection();
277 }
278
279 for (size_t i = 1; i <= kMaxActiveClients; i++) {
280 if (mPollFds[i].fd < 0) {
281 continue;
282 }
283
284 if (mPollFds[i].revents & POLLIN) {
285 handleClientData(mPollFds[i].fd);
286 }
287 }
288
289 // Mask all signals to ensure that sSignalReceived can't become true between
290 // checking it in the while condition and calling into ppoll()
291 maskAllSignals();
292 }
293 }
294
signalHandler(int signal)295 void SocketServer::signalHandler(int signal) {
296 LOGD("Caught signal %d", signal);
297 sSignalReceived = true;
298 }
299
300 } // namespace chre
301 } // namespace android
302