1 #include <alloca.h>
2 #include <errno.h>
3 #include <malloc.h>
4 #include <pthread.h>
5 #include <signal.h>
6 #include <string.h>
7 #include <arpa/inet.h>
8 #include <sys/socket.h>
9 #include <sys/types.h>
10
11 #define LOG_TAG "SocketClient"
12 #include <cutils/log.h>
13
14 #include <sysutils/SocketClient.h>
15
SocketClient(int socket,bool owned)16 SocketClient::SocketClient(int socket, bool owned) {
17 init(socket, owned, false);
18 }
19
SocketClient(int socket,bool owned,bool useCmdNum)20 SocketClient::SocketClient(int socket, bool owned, bool useCmdNum) {
21 init(socket, owned, useCmdNum);
22 }
23
init(int socket,bool owned,bool useCmdNum)24 void SocketClient::init(int socket, bool owned, bool useCmdNum) {
25 mSocket = socket;
26 mSocketOwned = owned;
27 mUseCmdNum = useCmdNum;
28 pthread_mutex_init(&mWriteMutex, NULL);
29 pthread_mutex_init(&mRefCountMutex, NULL);
30 mPid = -1;
31 mUid = -1;
32 mGid = -1;
33 mRefCount = 1;
34 mCmdNum = 0;
35
36 struct ucred creds;
37 socklen_t szCreds = sizeof(creds);
38 memset(&creds, 0, szCreds);
39
40 int err = getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
41 if (err == 0) {
42 mPid = creds.pid;
43 mUid = creds.uid;
44 mGid = creds.gid;
45 }
46 }
47
~SocketClient()48 SocketClient::~SocketClient() {
49 if (mSocketOwned) {
50 close(mSocket);
51 }
52 }
53
sendMsg(int code,const char * msg,bool addErrno)54 int SocketClient::sendMsg(int code, const char *msg, bool addErrno) {
55 return sendMsg(code, msg, addErrno, mUseCmdNum);
56 }
57
sendMsg(int code,const char * msg,bool addErrno,bool useCmdNum)58 int SocketClient::sendMsg(int code, const char *msg, bool addErrno, bool useCmdNum) {
59 char *buf;
60 int ret = 0;
61
62 if (addErrno) {
63 if (useCmdNum) {
64 ret = asprintf(&buf, "%d %d %s (%s)", code, getCmdNum(), msg, strerror(errno));
65 } else {
66 ret = asprintf(&buf, "%d %s (%s)", code, msg, strerror(errno));
67 }
68 } else {
69 if (useCmdNum) {
70 ret = asprintf(&buf, "%d %d %s", code, getCmdNum(), msg);
71 } else {
72 ret = asprintf(&buf, "%d %s", code, msg);
73 }
74 }
75 // Send the zero-terminated message
76 if (ret != -1) {
77 ret = sendMsg(buf);
78 free(buf);
79 }
80 return ret;
81 }
82
83 // send 3-digit code, null, binary-length, binary data
sendBinaryMsg(int code,const void * data,int len)84 int SocketClient::sendBinaryMsg(int code, const void *data, int len) {
85
86 // 4 bytes for the code & null + 4 bytes for the len
87 char buf[8];
88 // Write the code
89 snprintf(buf, 4, "%.3d", code);
90 // Write the len
91 uint32_t tmp = htonl(len);
92 memcpy(buf + 4, &tmp, sizeof(uint32_t));
93
94 struct iovec vec[2];
95 vec[0].iov_base = (void *) buf;
96 vec[0].iov_len = sizeof(buf);
97 vec[1].iov_base = (void *) data;
98 vec[1].iov_len = len;
99
100 pthread_mutex_lock(&mWriteMutex);
101 int result = sendDataLockedv(vec, (len > 0) ? 2 : 1);
102 pthread_mutex_unlock(&mWriteMutex);
103
104 return result;
105 }
106
107 // Sends the code (c-string null-terminated).
sendCode(int code)108 int SocketClient::sendCode(int code) {
109 char buf[4];
110 snprintf(buf, sizeof(buf), "%.3d", code);
111 return sendData(buf, sizeof(buf));
112 }
113
quoteArg(const char * arg)114 char *SocketClient::quoteArg(const char *arg) {
115 int len = strlen(arg);
116 char *result = (char *)malloc(len * 2 + 3);
117 char *current = result;
118 const char *end = arg + len;
119 char *oldresult;
120
121 if(result == NULL) {
122 SLOGW("malloc error (%s)", strerror(errno));
123 return NULL;
124 }
125
126 *(current++) = '"';
127 while (arg < end) {
128 switch (*arg) {
129 case '\\':
130 case '"':
131 *(current++) = '\\'; // fallthrough
132 default:
133 *(current++) = *(arg++);
134 }
135 }
136 *(current++) = '"';
137 *(current++) = '\0';
138 oldresult = result; // save pointer in case realloc fails
139 result = (char *)realloc(result, current-result);
140 return result ? result : oldresult;
141 }
142
143
sendMsg(const char * msg)144 int SocketClient::sendMsg(const char *msg) {
145 // Send the message including null character
146 if (sendData(msg, strlen(msg) + 1) != 0) {
147 SLOGW("Unable to send msg '%s'", msg);
148 return -1;
149 }
150 return 0;
151 }
152
sendData(const void * data,int len)153 int SocketClient::sendData(const void *data, int len) {
154 struct iovec vec[1];
155 vec[0].iov_base = (void *) data;
156 vec[0].iov_len = len;
157
158 pthread_mutex_lock(&mWriteMutex);
159 int rc = sendDataLockedv(vec, 1);
160 pthread_mutex_unlock(&mWriteMutex);
161
162 return rc;
163 }
164
sendDatav(struct iovec * iov,int iovcnt)165 int SocketClient::sendDatav(struct iovec *iov, int iovcnt) {
166 pthread_mutex_lock(&mWriteMutex);
167 int rc = sendDataLockedv(iov, iovcnt);
168 pthread_mutex_unlock(&mWriteMutex);
169
170 return rc;
171 }
172
sendDataLockedv(struct iovec * iov,int iovcnt)173 int SocketClient::sendDataLockedv(struct iovec *iov, int iovcnt) {
174
175 if (mSocket < 0) {
176 errno = EHOSTUNREACH;
177 return -1;
178 }
179
180 if (iovcnt <= 0) {
181 return 0;
182 }
183
184 int ret = 0;
185 int e = 0; // SLOGW and sigaction are not inert regarding errno
186 int current = 0;
187
188 struct sigaction new_action, old_action;
189 memset(&new_action, 0, sizeof(new_action));
190 new_action.sa_handler = SIG_IGN;
191 sigaction(SIGPIPE, &new_action, &old_action);
192
193 for (;;) {
194 ssize_t rc = TEMP_FAILURE_RETRY(
195 writev(mSocket, iov + current, iovcnt - current));
196
197 if (rc > 0) {
198 size_t written = rc;
199 while ((current < iovcnt) && (written >= iov[current].iov_len)) {
200 written -= iov[current].iov_len;
201 current++;
202 }
203 if (current == iovcnt) {
204 break;
205 }
206 iov[current].iov_base = (char *)iov[current].iov_base + written;
207 iov[current].iov_len -= written;
208 continue;
209 }
210
211 if (rc == 0) {
212 e = EIO;
213 SLOGW("0 length write :(");
214 } else {
215 e = errno;
216 SLOGW("write error (%s)", strerror(e));
217 }
218 ret = -1;
219 break;
220 }
221
222 sigaction(SIGPIPE, &old_action, &new_action);
223
224 if (e != 0) {
225 errno = e;
226 }
227 return ret;
228 }
229
incRef()230 void SocketClient::incRef() {
231 pthread_mutex_lock(&mRefCountMutex);
232 mRefCount++;
233 pthread_mutex_unlock(&mRefCountMutex);
234 }
235
decRef()236 bool SocketClient::decRef() {
237 bool deleteSelf = false;
238 pthread_mutex_lock(&mRefCountMutex);
239 mRefCount--;
240 if (mRefCount == 0) {
241 deleteSelf = true;
242 } else if (mRefCount < 0) {
243 SLOGE("SocketClient refcount went negative!");
244 }
245 pthread_mutex_unlock(&mRefCountMutex);
246 if (deleteSelf) {
247 delete this;
248 }
249 return deleteSelf;
250 }
251