1 #include <alloca.h>
2 #include <errno.h>
3 #include <malloc.h>
4 #include <pthread.h>
5 #include <signal.h>
6 #include <string.h>
7 #include <arpa/inet.h>
8 #include <sys/socket.h>
9 #include <sys/types.h>
10 
11 #define LOG_TAG "SocketClient"
12 #include <cutils/log.h>
13 
14 #include <sysutils/SocketClient.h>
15 
SocketClient(int socket,bool owned)16 SocketClient::SocketClient(int socket, bool owned) {
17     init(socket, owned, false);
18 }
19 
SocketClient(int socket,bool owned,bool useCmdNum)20 SocketClient::SocketClient(int socket, bool owned, bool useCmdNum) {
21     init(socket, owned, useCmdNum);
22 }
23 
init(int socket,bool owned,bool useCmdNum)24 void SocketClient::init(int socket, bool owned, bool useCmdNum) {
25     mSocket = socket;
26     mSocketOwned = owned;
27     mUseCmdNum = useCmdNum;
28     pthread_mutex_init(&mWriteMutex, NULL);
29     pthread_mutex_init(&mRefCountMutex, NULL);
30     mPid = -1;
31     mUid = -1;
32     mGid = -1;
33     mRefCount = 1;
34     mCmdNum = 0;
35 
36     struct ucred creds;
37     socklen_t szCreds = sizeof(creds);
38     memset(&creds, 0, szCreds);
39 
40     int err = getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
41     if (err == 0) {
42         mPid = creds.pid;
43         mUid = creds.uid;
44         mGid = creds.gid;
45     }
46 }
47 
~SocketClient()48 SocketClient::~SocketClient() {
49     if (mSocketOwned) {
50         close(mSocket);
51     }
52 }
53 
sendMsg(int code,const char * msg,bool addErrno)54 int SocketClient::sendMsg(int code, const char *msg, bool addErrno) {
55     return sendMsg(code, msg, addErrno, mUseCmdNum);
56 }
57 
sendMsg(int code,const char * msg,bool addErrno,bool useCmdNum)58 int SocketClient::sendMsg(int code, const char *msg, bool addErrno, bool useCmdNum) {
59     char *buf;
60     int ret = 0;
61 
62     if (addErrno) {
63         if (useCmdNum) {
64             ret = asprintf(&buf, "%d %d %s (%s)", code, getCmdNum(), msg, strerror(errno));
65         } else {
66             ret = asprintf(&buf, "%d %s (%s)", code, msg, strerror(errno));
67         }
68     } else {
69         if (useCmdNum) {
70             ret = asprintf(&buf, "%d %d %s", code, getCmdNum(), msg);
71         } else {
72             ret = asprintf(&buf, "%d %s", code, msg);
73         }
74     }
75     // Send the zero-terminated message
76     if (ret != -1) {
77         ret = sendMsg(buf);
78         free(buf);
79     }
80     return ret;
81 }
82 
83 // send 3-digit code, null, binary-length, binary data
sendBinaryMsg(int code,const void * data,int len)84 int SocketClient::sendBinaryMsg(int code, const void *data, int len) {
85 
86     // 4 bytes for the code & null + 4 bytes for the len
87     char buf[8];
88     // Write the code
89     snprintf(buf, 4, "%.3d", code);
90     // Write the len
91     uint32_t tmp = htonl(len);
92     memcpy(buf + 4, &tmp, sizeof(uint32_t));
93 
94     struct iovec vec[2];
95     vec[0].iov_base = (void *) buf;
96     vec[0].iov_len = sizeof(buf);
97     vec[1].iov_base = (void *) data;
98     vec[1].iov_len = len;
99 
100     pthread_mutex_lock(&mWriteMutex);
101     int result = sendDataLockedv(vec, (len > 0) ? 2 : 1);
102     pthread_mutex_unlock(&mWriteMutex);
103 
104     return result;
105 }
106 
107 // Sends the code (c-string null-terminated).
sendCode(int code)108 int SocketClient::sendCode(int code) {
109     char buf[4];
110     snprintf(buf, sizeof(buf), "%.3d", code);
111     return sendData(buf, sizeof(buf));
112 }
113 
quoteArg(const char * arg)114 char *SocketClient::quoteArg(const char *arg) {
115     int len = strlen(arg);
116     char *result = (char *)malloc(len * 2 + 3);
117     char *current = result;
118     const char *end = arg + len;
119     char *oldresult;
120 
121     if(result == NULL) {
122         SLOGW("malloc error (%s)", strerror(errno));
123         return NULL;
124     }
125 
126     *(current++) = '"';
127     while (arg < end) {
128         switch (*arg) {
129         case '\\':
130         case '"':
131             *(current++) = '\\'; // fallthrough
132         default:
133             *(current++) = *(arg++);
134         }
135     }
136     *(current++) = '"';
137     *(current++) = '\0';
138     oldresult = result; // save pointer in case realloc fails
139     result = (char *)realloc(result, current-result);
140     return result ? result : oldresult;
141 }
142 
143 
sendMsg(const char * msg)144 int SocketClient::sendMsg(const char *msg) {
145     // Send the message including null character
146     if (sendData(msg, strlen(msg) + 1) != 0) {
147         SLOGW("Unable to send msg '%s'", msg);
148         return -1;
149     }
150     return 0;
151 }
152 
sendData(const void * data,int len)153 int SocketClient::sendData(const void *data, int len) {
154     struct iovec vec[1];
155     vec[0].iov_base = (void *) data;
156     vec[0].iov_len = len;
157 
158     pthread_mutex_lock(&mWriteMutex);
159     int rc = sendDataLockedv(vec, 1);
160     pthread_mutex_unlock(&mWriteMutex);
161 
162     return rc;
163 }
164 
sendDatav(struct iovec * iov,int iovcnt)165 int SocketClient::sendDatav(struct iovec *iov, int iovcnt) {
166     pthread_mutex_lock(&mWriteMutex);
167     int rc = sendDataLockedv(iov, iovcnt);
168     pthread_mutex_unlock(&mWriteMutex);
169 
170     return rc;
171 }
172 
sendDataLockedv(struct iovec * iov,int iovcnt)173 int SocketClient::sendDataLockedv(struct iovec *iov, int iovcnt) {
174 
175     if (mSocket < 0) {
176         errno = EHOSTUNREACH;
177         return -1;
178     }
179 
180     if (iovcnt <= 0) {
181         return 0;
182     }
183 
184     int ret = 0;
185     int e = 0; // SLOGW and sigaction are not inert regarding errno
186     int current = 0;
187 
188     struct sigaction new_action, old_action;
189     memset(&new_action, 0, sizeof(new_action));
190     new_action.sa_handler = SIG_IGN;
191     sigaction(SIGPIPE, &new_action, &old_action);
192 
193     for (;;) {
194         ssize_t rc = TEMP_FAILURE_RETRY(
195             writev(mSocket, iov + current, iovcnt - current));
196 
197         if (rc > 0) {
198             size_t written = rc;
199             while ((current < iovcnt) && (written >= iov[current].iov_len)) {
200                 written -= iov[current].iov_len;
201                 current++;
202             }
203             if (current == iovcnt) {
204                 break;
205             }
206             iov[current].iov_base = (char *)iov[current].iov_base + written;
207             iov[current].iov_len -= written;
208             continue;
209         }
210 
211         if (rc == 0) {
212             e = EIO;
213             SLOGW("0 length write :(");
214         } else {
215             e = errno;
216             SLOGW("write error (%s)", strerror(e));
217         }
218         ret = -1;
219         break;
220     }
221 
222     sigaction(SIGPIPE, &old_action, &new_action);
223 
224     if (e != 0) {
225         errno = e;
226     }
227     return ret;
228 }
229 
incRef()230 void SocketClient::incRef() {
231     pthread_mutex_lock(&mRefCountMutex);
232     mRefCount++;
233     pthread_mutex_unlock(&mRefCountMutex);
234 }
235 
decRef()236 bool SocketClient::decRef() {
237     bool deleteSelf = false;
238     pthread_mutex_lock(&mRefCountMutex);
239     mRefCount--;
240     if (mRefCount == 0) {
241         deleteSelf = true;
242     } else if (mRefCount < 0) {
243         SLOGE("SocketClient refcount went negative!");
244     }
245     pthread_mutex_unlock(&mRefCountMutex);
246     if (deleteSelf) {
247         delete this;
248     }
249     return deleteSelf;
250 }
251