1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre_host/socket_client.h"
18 
19 #include <inttypes.h>
20 
21 #include <string.h>
22 #include <unistd.h>
23 
24 #include <chrono>
25 
26 #include <cutils/sockets.h>
27 #include <sys/socket.h>
28 #include <utils/RefBase.h>
29 #include <utils/StrongPointer.h>
30 
31 #include "chre_host/log.h"
32 
33 namespace android {
34 namespace chre {
35 
SocketClient()36 SocketClient::SocketClient() {
37   std::atomic_init(&mSockFd, INVALID_SOCKET);
38 }
39 
~SocketClient()40 SocketClient::~SocketClient() {
41   disconnect();
42 }
43 
connect(const char * socketName,const sp<ICallbacks> & callbacks)44 bool SocketClient::connect(const char *socketName,
45                            const sp<ICallbacks> &callbacks) {
46   return doConnect(socketName, callbacks, false /* connectInBackground */);
47 }
48 
connectInBackground(const char * socketName,const sp<ICallbacks> & callbacks)49 bool SocketClient::connectInBackground(const char *socketName,
50                                        const sp<ICallbacks> &callbacks) {
51   return doConnect(socketName, callbacks, true /* connectInBackground */);
52 }
53 
disconnect()54 void SocketClient::disconnect() {
55   if (inReceiveThread()) {
56     LOGE("disconnect() can't be called from a receive thread callback");
57   } else if (receiveThreadRunning()) {
58     // Inform the RX thread that we're requesting a shutdown, breaking it out of
59     // the retry wait if it's currently blocked there
60     {
61       std::lock_guard<std::mutex> lock(mShutdownMutex);
62       mGracefulShutdown = true;
63     }
64     mShutdownCond.notify_all();
65 
66     // Invalidate the socket (will kick the RX thread out of recv if it's
67     // currently blocked there)
68     if (mSockFd != INVALID_SOCKET && shutdown(mSockFd, SHUT_RDWR) != 0) {
69       LOG_ERROR("Couldn't shut down socket", errno);
70     }
71 
72     if (mRxThread.joinable()) {
73       LOGD("Waiting for RX thread to exit");
74       mRxThread.join();
75     }
76   }
77 }
78 
isConnected() const79 bool SocketClient::isConnected() const {
80   return (mSockFd != INVALID_SOCKET);
81 }
82 
sendMessage(const void * data,size_t length)83 bool SocketClient::sendMessage(const void *data, size_t length) {
84   bool success = false;
85 
86   if (mSockFd == INVALID_SOCKET) {
87     LOGW("Tried sending a message, but don't have a valid socket handle");
88   } else {
89     ssize_t bytesSent = send(mSockFd, data, length, 0);
90     if (bytesSent < 0) {
91       LOGE("Failed to send %zu bytes of data: %s", length, strerror(errno));
92     } else if (bytesSent == 0) {
93       LOGW("Failed to send data; remote side disconnected");
94     } else if (static_cast<size_t>(bytesSent) != length) {
95       LOGW("Truncated packet, tried sending %zu bytes, only %zd went through",
96            length, bytesSent);
97     } else {
98       success = true;
99     }
100   }
101 
102   return success;
103 }
104 
doConnect(const char * socketName,const sp<ICallbacks> & callbacks,bool connectInBackground)105 bool SocketClient::doConnect(const char *socketName,
106                              const sp<ICallbacks> &callbacks,
107                              bool connectInBackground) {
108   bool success = false;
109   if (inReceiveThread()) {
110     LOGE("Can't attempt to connect from a receive thread callback");
111   } else {
112     if (receiveThreadRunning()) {
113       LOGW("Re-connecting socket with implicit disconnect");
114       disconnect();
115     }
116 
117     size_t socketNameLen =
118         strlcpy(mSocketName, socketName, sizeof(mSocketName));
119     if (socketNameLen >= sizeof(mSocketName)) {
120       LOGE("Socket name length parameter is too long (%zu, max %zu)",
121            socketNameLen, sizeof(mSocketName));
122     } else if (callbacks == nullptr) {
123       LOGE("Callbacks parameter must be provided");
124     } else if (connectInBackground || tryConnect()) {
125       mGracefulShutdown = false;
126       mCallbacks = callbacks;
127       mRxThread = std::thread([this]() { receiveThread(); });
128       success = true;
129     }
130   }
131 
132   return success;
133 }
134 
inReceiveThread() const135 bool SocketClient::inReceiveThread() const {
136   return (std::this_thread::get_id() == mRxThread.get_id());
137 }
138 
receiveThread()139 void SocketClient::receiveThread() {
140   constexpr size_t kReceiveBufferSize = 4096;
141   uint8_t buffer[kReceiveBufferSize];
142 
143   LOGV("Receive thread started");
144   while (!mGracefulShutdown && (mSockFd != INVALID_SOCKET || reconnect())) {
145     while (!mGracefulShutdown) {
146       ssize_t bytesReceived = recv(mSockFd, buffer, sizeof(buffer), 0);
147       if (bytesReceived < 0) {
148         LOG_ERROR("Exiting RX thread", errno);
149         break;
150       } else if (bytesReceived == 0) {
151         if (!mGracefulShutdown) {
152           LOGI("Socket disconnected on remote end");
153           mCallbacks->onDisconnected();
154         }
155         break;
156       }
157 
158       mCallbacks->onMessageReceived(buffer, bytesReceived);
159     }
160 
161     if (close(mSockFd) != 0) {
162       LOG_ERROR("Couldn't close socket", errno);
163     }
164     mSockFd = INVALID_SOCKET;
165   }
166 
167   if (!mGracefulShutdown) {
168     mCallbacks->onConnectionAborted();
169   }
170 
171   mCallbacks.clear();
172   LOGV("Exiting receive thread");
173 }
174 
receiveThreadRunning() const175 bool SocketClient::receiveThreadRunning() const {
176   return mRxThread.joinable();
177 }
178 
reconnect()179 bool SocketClient::reconnect() {
180   constexpr auto kMinDelay = std::chrono::duration<int32_t, std::milli>(250);
181   constexpr auto kMaxDelay = std::chrono::minutes(5);
182   // Try reconnecting at initial delay this many times before backing off
183   constexpr unsigned int kExponentialBackoffDelay =
184       std::chrono::seconds(10) / kMinDelay;
185   // Give up after this many tries (~2.5 hours)
186   constexpr unsigned int kRetryLimit = kExponentialBackoffDelay + 40;
187   auto delay = kMinDelay;
188   unsigned int retryCount = 0;
189 
190   while (retryCount++ < kRetryLimit) {
191     {
192       std::unique_lock<std::mutex> lock(mShutdownMutex);
193       mShutdownCond.wait_for(lock, delay,
194                              [this]() { return mGracefulShutdown.load(); });
195       if (mGracefulShutdown) {
196         break;
197       }
198     }
199 
200     bool suppressErrorLogs = (delay == kMinDelay);
201     if (!tryConnect(suppressErrorLogs)) {
202       if (!suppressErrorLogs) {
203         LOGW("Failed to (re)connect, next try in %" PRId32 " ms",
204              delay.count());
205       }
206       if (retryCount > kExponentialBackoffDelay) {
207         delay *= 2;
208       }
209       if (delay > kMaxDelay) {
210         delay = kMaxDelay;
211       }
212     } else {
213       LOGD("Successfully (re)connected");
214       mCallbacks->onConnected();
215       return true;
216     }
217   }
218 
219   return false;
220 }
221 
tryConnect(bool suppressErrorLogs)222 bool SocketClient::tryConnect(bool suppressErrorLogs) {
223   bool success = false;
224 
225   errno = 0;
226   int sockFd = socket(AF_LOCAL, SOCK_SEQPACKET, 0);
227   if (sockFd >= 0) {
228     // Set the send buffer size to 2MB to allow plenty of room for nanoapp
229     // loading
230     int sndbuf = 2 * 1024 * 1024;
231     // Normally, send() should effectively return immediately, but in the event
232     // that we get blocked due to flow control, don't stay blocked for more than
233     // 3 seconds
234     struct timeval timeout = {
235         .tv_sec = 3,
236         .tv_usec = 0,
237     };
238     int ret;
239 
240     if ((ret = setsockopt(sockFd, SOL_SOCKET, SO_SNDBUF, &sndbuf,
241                           sizeof(sndbuf))) != 0) {
242       if (!suppressErrorLogs) {
243         LOGE("Failed to set SO_SNDBUF to %d: %s", sndbuf, strerror(errno));
244       }
245     } else if ((ret = setsockopt(sockFd, SOL_SOCKET, SO_SNDTIMEO, &timeout,
246                                  sizeof(timeout))) != 0) {
247       if (!suppressErrorLogs) {
248         LOGE("Failed to set SO_SNDTIMEO: %s", strerror(errno));
249       }
250     } else {
251       mSockFd = socket_local_client_connect(sockFd, mSocketName,
252                                             ANDROID_SOCKET_NAMESPACE_RESERVED,
253                                             SOCK_SEQPACKET);
254       if (mSockFd != INVALID_SOCKET) {
255         success = true;
256       } else if (!suppressErrorLogs) {
257         LOGE("Couldn't connect client socket to '%s': %s", mSocketName,
258              strerror(errno));
259       }
260     }
261 
262     if (!success) {
263       close(sockFd);
264     }
265   } else if (!suppressErrorLogs) {
266     LOGE("Couldn't create local socket: %s", strerror(errno));
267   }
268 
269   return success;
270 }
271 
272 }  // namespace chre
273 }  // namespace android
274