1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre_host/socket_client.h"
18 
19 #include <inttypes.h>
20 
21 #include <string.h>
22 
23 #include <chrono>
24 
25 #include <cutils/sockets.h>
26 #include <utils/RefBase.h>
27 #include <utils/StrongPointer.h>
28 
29 #include "chre_host/log.h"
30 
31 namespace android {
32 namespace chre {
33 
SocketClient()34 SocketClient::SocketClient() {
35   std::atomic_init(&mSockFd, INVALID_SOCKET);
36 }
37 
~SocketClient()38 SocketClient::~SocketClient() {
39   disconnect();
40 }
41 
connect(const char * socketName,const sp<ICallbacks> & callbacks)42 bool SocketClient::connect(const char *socketName,
43                            const sp<ICallbacks>& callbacks) {
44   return doConnect(socketName, callbacks, false /* connectInBackground */);
45 }
46 
connectInBackground(const char * socketName,const sp<ICallbacks> & callbacks)47 bool SocketClient::connectInBackground(const char *socketName,
48                                        const sp<ICallbacks>& callbacks) {
49   return doConnect(socketName, callbacks, true /* connectInBackground */);
50 }
51 
disconnect()52 void SocketClient::disconnect() {
53   if (inReceiveThread()) {
54     LOGE("disconnect() can't be called from a receive thread callback");
55   } else if (receiveThreadRunning()) {
56     // Inform the RX thread that we're requesting a shutdown, breaking it out of
57     // the retry wait if it's currently blocked there
58     {
59       std::lock_guard<std::mutex> lock(mShutdownMutex);
60       mGracefulShutdown = true;
61     }
62     mShutdownCond.notify_all();
63 
64     // Invalidate the socket (will kick the RX thread out of recv if it's
65     // currently blocked there)
66     if (mSockFd != INVALID_SOCKET && shutdown(mSockFd, SHUT_RDWR) != 0) {
67       LOG_ERROR("Couldn't shut down socket", errno);
68     }
69 
70     if (mRxThread.joinable()) {
71       LOGD("Waiting for RX thread to exit");
72       mRxThread.join();
73     }
74   }
75 }
76 
isConnected() const77 bool SocketClient::isConnected() const {
78   return (mSockFd != INVALID_SOCKET);
79 }
80 
sendMessage(const void * data,size_t length)81 bool SocketClient::sendMessage(const void *data, size_t length) {
82   bool success = false;
83 
84   if (mSockFd == INVALID_SOCKET) {
85     LOGW("Tried sending a message, but don't have a valid socket handle");
86   } else {
87     ssize_t bytesSent = send(mSockFd, data, length, 0);
88     if (bytesSent < 0) {
89       LOGE("Failed to send %zu bytes of data: %s", length, strerror(errno));
90     } else if (bytesSent == 0) {
91       LOGW("Failed to send data; remote side disconnected");
92     } else if (static_cast<size_t>(bytesSent) != length) {
93       LOGW("Truncated packet, tried sending %zu bytes, only %zd went through",
94            length, bytesSent);
95     } else {
96       success = true;
97     }
98   }
99 
100   return success;
101 }
102 
doConnect(const char * socketName,const sp<ICallbacks> & callbacks,bool connectInBackground)103 bool SocketClient::doConnect(const char *socketName,
104                              const sp<ICallbacks>& callbacks,
105                              bool connectInBackground) {
106   bool success = false;
107   if (inReceiveThread()) {
108     LOGE("Can't attempt to connect from a receive thread callback");
109   } else {
110     if (receiveThreadRunning()) {
111       LOGW("Re-connecting socket with implicit disconnect");
112       disconnect();
113     }
114 
115     size_t socketNameLen = strlcpy(mSocketName, socketName,
116                                    sizeof(mSocketName));
117     if (socketNameLen >= sizeof(mSocketName)) {
118       LOGE("Socket name length parameter is too long (%zu, max %zu)",
119            socketNameLen, sizeof(mSocketName));
120     } else if (callbacks == nullptr) {
121       LOGE("Callbacks parameter must be provided");
122     } else if (connectInBackground || tryConnect()) {
123       mGracefulShutdown = false;
124       mCallbacks = callbacks;
125       mRxThread = std::thread([this]() {
126         receiveThread();
127       });
128       success = true;
129     }
130   }
131 
132   return success;
133 }
134 
inReceiveThread() const135 bool SocketClient::inReceiveThread() const {
136   return (std::this_thread::get_id() == mRxThread.get_id());
137 }
138 
receiveThread()139 void SocketClient::receiveThread() {
140   constexpr size_t kReceiveBufferSize = 4096;
141   uint8_t buffer[kReceiveBufferSize];
142 
143   LOGV("Receive thread started");
144   while (!mGracefulShutdown && (mSockFd != INVALID_SOCKET || reconnect())) {
145     while (!mGracefulShutdown) {
146       ssize_t bytesReceived = recv(mSockFd, buffer, sizeof(buffer), 0);
147       if (bytesReceived < 0) {
148         LOG_ERROR("Exiting RX thread", errno);
149         break;
150       } else if (bytesReceived == 0) {
151         if (!mGracefulShutdown) {
152           LOGI("Socket disconnected on remote end");
153           mCallbacks->onDisconnected();
154         }
155         break;
156       }
157 
158       mCallbacks->onMessageReceived(buffer, bytesReceived);
159     }
160 
161     if (close(mSockFd) != 0) {
162       LOG_ERROR("Couldn't close socket", errno);
163     }
164     mSockFd = INVALID_SOCKET;
165   }
166 
167   if (!mGracefulShutdown) {
168     mCallbacks->onConnectionAborted();
169   }
170 
171   mCallbacks.clear();
172   LOGV("Exiting receive thread");
173 }
174 
receiveThreadRunning() const175 bool SocketClient::receiveThreadRunning() const {
176   return mRxThread.joinable();
177 }
178 
reconnect()179 bool SocketClient::reconnect() {
180   auto delay = std::chrono::duration<int32_t, std::milli>(500);
181   constexpr auto kMaxDelay = std::chrono::minutes(5);
182   int retryLimit = 40;  // ~2.5 hours total
183 
184   while (--retryLimit > 0) {
185     {
186       std::unique_lock<std::mutex> lock(mShutdownMutex);
187       mShutdownCond.wait_for(lock, delay,
188                              [this]() { return mGracefulShutdown.load(); });
189       if (mGracefulShutdown) {
190         break;
191       }
192     }
193 
194     if (!tryConnect()) {
195       LOGW("Failed to (re)connect, next try in %" PRId32 " ms", delay.count());
196       delay *= 2;
197       if (delay > kMaxDelay) {
198         delay = kMaxDelay;
199       }
200     } else {
201       LOGD("Successfully (re)connected");
202       mCallbacks->onConnected();
203       return true;
204     }
205   }
206 
207   return false;
208 }
209 
tryConnect()210 bool SocketClient::tryConnect() {
211   errno = 0;
212   mSockFd = socket_local_client(mSocketName,
213                                 ANDROID_SOCKET_NAMESPACE_RESERVED,
214                                 SOCK_SEQPACKET);
215   if (mSockFd == INVALID_SOCKET) {
216     LOGE("Couldn't create/connect client socket to '%s': %s",
217          mSocketName, strerror(errno));
218   }
219 
220   return (mSockFd != INVALID_SOCKET);
221 }
222 
223 }  // namespace chre
224 }  // namespace android
225