1 /*
2  * Copyright (C) 2006 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "LocalSocketImpl"
18 
19 #include "JNIHelp.h"
20 #include "jni.h"
21 #include "utils/Log.h"
22 #include "utils/misc.h"
23 
24 #include <stdio.h>
25 #include <string.h>
26 #include <sys/types.h>
27 #include <sys/socket.h>
28 #include <sys/un.h>
29 #include <arpa/inet.h>
30 #include <netinet/in.h>
31 #include <stdlib.h>
32 #include <errno.h>
33 #include <unistd.h>
34 #include <sys/ioctl.h>
35 
36 #include <cutils/sockets.h>
37 #include <netinet/tcp.h>
38 #include <ScopedUtfChars.h>
39 
40 namespace android {
41 
// Consumes a value and does nothing with it; used where a return value is
// deliberately ignored, to keep "set but not used" warnings quiet.
template <typename T>
void UNUSED(T /* ignored */) {
}
44 
// JNI handles cached once by register_android_net_LocalSocketImpl().
// Field IDs of LocalSocketImpl's FileDescriptor[] fields: inbound is filled
// from received SCM_RIGHTS data, outbound is read when writing.
static jfieldID field_inboundFileDescriptors;
static jfieldID field_outboundFileDescriptors;
// Global references (created with NewGlobalRef) to android.net.Credentials
// and java.io.FileDescriptor, so they stay valid across native calls.
static jclass class_Credentials;
static jclass class_FileDescriptor;
// Credentials(int, int, int) constructor, invoked with pid, uid, gid.
static jmethodID method_CredentialsInit;
50 
51 /* private native void connectLocal(FileDescriptor fd,
52  * String name, int namespace) throws IOException
53  */
54 static void
socket_connect_local(JNIEnv * env,jobject object,jobject fileDescriptor,jstring name,jint namespaceId)55 socket_connect_local(JNIEnv *env, jobject object,
56                         jobject fileDescriptor, jstring name, jint namespaceId)
57 {
58     int ret;
59     int fd;
60 
61     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
62 
63     if (env->ExceptionCheck()) {
64         return;
65     }
66 
67     ScopedUtfChars nameUtf8(env, name);
68 
69     ret = socket_local_client_connect(
70                 fd,
71                 nameUtf8.c_str(),
72                 namespaceId,
73                 SOCK_STREAM);
74 
75     if (ret < 0) {
76         jniThrowIOException(env, errno);
77         return;
78     }
79 }
80 
// NOTE(review): DEFAULT_BACKLOG is not referenced anywhere in this file;
// presumably a leftover listen() backlog default - confirm before removing.
#define DEFAULT_BACKLOG 4
82 
83 /* private native void bindLocal(FileDescriptor fd, String name, namespace)
84  * throws IOException;
85  */
86 
87 static void
socket_bind_local(JNIEnv * env,jobject object,jobject fileDescriptor,jstring name,jint namespaceId)88 socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
89                 jstring name, jint namespaceId)
90 {
91     int ret;
92     int fd;
93 
94     if (name == NULL) {
95         jniThrowNullPointerException(env, NULL);
96         return;
97     }
98 
99     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
100 
101     if (env->ExceptionCheck()) {
102         return;
103     }
104 
105     ScopedUtfChars nameUtf8(env, name);
106 
107     ret = socket_local_server_bind(fd, nameUtf8.c_str(), namespaceId);
108 
109     if (ret < 0) {
110         jniThrowIOException(env, errno);
111         return;
112     }
113 }
114 
115 /**
116  * Processes ancillary data, handling only
117  * SCM_RIGHTS. Creates appropriate objects and sets appropriate
118  * fields in the LocalSocketImpl object. Returns 0 on success
119  * or -1 if an exception was thrown.
120  */
socket_process_cmsg(JNIEnv * env,jobject thisJ,struct msghdr * pMsg)121 static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
122 {
123     struct cmsghdr *cmsgptr;
124 
125     for (cmsgptr = CMSG_FIRSTHDR(pMsg);
126             cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
127 
128         if (cmsgptr->cmsg_level != SOL_SOCKET) {
129             continue;
130         }
131 
132         if (cmsgptr->cmsg_type == SCM_RIGHTS) {
133             int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
134             jobjectArray fdArray;
135             int count
136                 = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
137 
138             if (count < 0) {
139                 jniThrowException(env, "java/io/IOException",
140                     "invalid cmsg length");
141                 return -1;
142             }
143 
144             fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
145 
146             if (fdArray == NULL) {
147                 return -1;
148             }
149 
150             for (int i = 0; i < count; i++) {
151                 jobject fdObject
152                         = jniCreateFileDescriptor(env, pDescriptors[i]);
153 
154                 if (env->ExceptionCheck()) {
155                     return -1;
156                 }
157 
158                 env->SetObjectArrayElement(fdArray, i, fdObject);
159 
160                 if (env->ExceptionCheck()) {
161                     return -1;
162                 }
163             }
164 
165             env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
166 
167             if (env->ExceptionCheck()) {
168                 return -1;
169             }
170         }
171     }
172 
173     return 0;
174 }
175 
176 /**
177  * Reads data from a socket into buf, processing any ancillary data
178  * and adding it to thisJ.
179  *
180  * Returns the length of normal data read, or -1 if an exception has
181  * been thrown in this function.
182  */
socket_read_all(JNIEnv * env,jobject thisJ,int fd,void * buffer,size_t len)183 static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
184         void *buffer, size_t len)
185 {
186     ssize_t ret;
187     struct msghdr msg;
188     struct iovec iv;
189     unsigned char *buf = (unsigned char *)buffer;
190     // Enough buffer for a pile of fd's. We throw an exception if
191     // this buffer is too small.
192     struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
193 
194     memset(&msg, 0, sizeof(msg));
195     memset(&iv, 0, sizeof(iv));
196 
197     iv.iov_base = buf;
198     iv.iov_len = len;
199 
200     msg.msg_iov = &iv;
201     msg.msg_iovlen = 1;
202     msg.msg_control = cmsgbuf;
203     msg.msg_controllen = sizeof(cmsgbuf);
204 
205     ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
206 
207     if (ret < 0 && errno == EPIPE) {
208         // Treat this as an end of stream
209         return 0;
210     }
211 
212     if (ret < 0) {
213         jniThrowIOException(env, errno);
214         return -1;
215     }
216 
217     if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
218         // To us, any of the above flags are a fatal error
219 
220         jniThrowException(env, "java/io/IOException",
221                 "Unexpected error or truncation during recvmsg()");
222 
223         return -1;
224     }
225 
226     if (ret >= 0) {
227         socket_process_cmsg(env, thisJ, &msg);
228     }
229 
230     return ret;
231 }
232 
233 /**
234  * Writes all the data in the specified buffer to the specified socket.
235  *
236  * Returns 0 on success or -1 if an exception was thrown.
237  */
socket_write_all(JNIEnv * env,jobject object,int fd,void * buf,size_t len)238 static int socket_write_all(JNIEnv *env, jobject object, int fd,
239         void *buf, size_t len)
240 {
241     ssize_t ret;
242     struct msghdr msg;
243     unsigned char *buffer = (unsigned char *)buf;
244     memset(&msg, 0, sizeof(msg));
245 
246     jobjectArray outboundFds
247             = (jobjectArray)env->GetObjectField(
248                 object, field_outboundFileDescriptors);
249 
250     if (env->ExceptionCheck()) {
251         return -1;
252     }
253 
254     struct cmsghdr *cmsg;
255     int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
256     int fds[countFds];
257     char msgbuf[CMSG_SPACE(countFds)];
258 
259     // Add any pending outbound file descriptors to the message
260     if (outboundFds != NULL) {
261 
262         if (env->ExceptionCheck()) {
263             return -1;
264         }
265 
266         for (int i = 0; i < countFds; i++) {
267             jobject fdObject = env->GetObjectArrayElement(outboundFds, i);
268             if (env->ExceptionCheck()) {
269                 return -1;
270             }
271 
272             fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
273             if (env->ExceptionCheck()) {
274                 return -1;
275             }
276         }
277 
278         // See "man cmsg" really
279         msg.msg_control = msgbuf;
280         msg.msg_controllen = sizeof msgbuf;
281         cmsg = CMSG_FIRSTHDR(&msg);
282         cmsg->cmsg_level = SOL_SOCKET;
283         cmsg->cmsg_type = SCM_RIGHTS;
284         cmsg->cmsg_len = CMSG_LEN(sizeof fds);
285         memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
286     }
287 
288     // We only write our msg_control during the first write
289     while (len > 0) {
290         struct iovec iv;
291         memset(&iv, 0, sizeof(iv));
292 
293         iv.iov_base = buffer;
294         iv.iov_len = len;
295 
296         msg.msg_iov = &iv;
297         msg.msg_iovlen = 1;
298 
299         do {
300             ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
301         } while (ret < 0 && errno == EINTR);
302 
303         if (ret < 0) {
304             jniThrowIOException(env, errno);
305             return -1;
306         }
307 
308         buffer += ret;
309         len -= ret;
310 
311         // Wipes out any msg_control too
312         memset(&msg, 0, sizeof(msg));
313     }
314 
315     return 0;
316 }
317 
socket_read(JNIEnv * env,jobject object,jobject fileDescriptor)318 static jint socket_read (JNIEnv *env, jobject object, jobject fileDescriptor)
319 {
320     int fd;
321     int err;
322 
323     if (fileDescriptor == NULL) {
324         jniThrowNullPointerException(env, NULL);
325         return (jint)-1;
326     }
327 
328     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
329 
330     if (env->ExceptionCheck()) {
331         return (jint)0;
332     }
333 
334     unsigned char buf;
335 
336     err = socket_read_all(env, object, fd, &buf, 1);
337 
338     if (err < 0) {
339         jniThrowIOException(env, errno);
340         return (jint)0;
341     }
342 
343     if (err == 0) {
344         // end of file
345         return (jint)-1;
346     }
347 
348     return (jint)buf;
349 }
350 
socket_readba(JNIEnv * env,jobject object,jbyteArray buffer,jint off,jint len,jobject fileDescriptor)351 static jint socket_readba (JNIEnv *env, jobject object,
352         jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
353 {
354     int fd;
355     jbyte* byteBuffer;
356     int ret;
357 
358     if (fileDescriptor == NULL || buffer == NULL) {
359         jniThrowNullPointerException(env, NULL);
360         return (jint)-1;
361     }
362 
363     if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
364         jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
365         return (jint)-1;
366     }
367 
368     if (len == 0) {
369         // because socket_read_all returns 0 on EOF
370         return 0;
371     }
372 
373     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
374 
375     if (env->ExceptionCheck()) {
376         return (jint)-1;
377     }
378 
379     byteBuffer = env->GetByteArrayElements(buffer, NULL);
380 
381     if (NULL == byteBuffer) {
382         // an exception will have been thrown
383         return (jint)-1;
384     }
385 
386     ret = socket_read_all(env, object,
387             fd, byteBuffer + off, len);
388 
389     // A return of -1 above means an exception is pending
390 
391     env->ReleaseByteArrayElements(buffer, byteBuffer, 0);
392 
393     return (jint) ((ret == 0) ? -1 : ret);
394 }
395 
socket_write(JNIEnv * env,jobject object,jint b,jobject fileDescriptor)396 static void socket_write (JNIEnv *env, jobject object,
397         jint b, jobject fileDescriptor)
398 {
399     int fd;
400     int err;
401 
402     if (fileDescriptor == NULL) {
403         jniThrowNullPointerException(env, NULL);
404         return;
405     }
406 
407     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
408 
409     if (env->ExceptionCheck()) {
410         return;
411     }
412 
413     err = socket_write_all(env, object, fd, &b, 1);
414     UNUSED(err);
415     // A return of -1 above means an exception is pending
416 }
417 
socket_writeba(JNIEnv * env,jobject object,jbyteArray buffer,jint off,jint len,jobject fileDescriptor)418 static void socket_writeba (JNIEnv *env, jobject object,
419         jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
420 {
421     int fd;
422     int err;
423     jbyte* byteBuffer;
424 
425     if (fileDescriptor == NULL || buffer == NULL) {
426         jniThrowNullPointerException(env, NULL);
427         return;
428     }
429 
430     if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
431         jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
432         return;
433     }
434 
435     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
436 
437     if (env->ExceptionCheck()) {
438         return;
439     }
440 
441     byteBuffer = env->GetByteArrayElements(buffer,NULL);
442 
443     if (NULL == byteBuffer) {
444         // an exception will have been thrown
445         return;
446     }
447 
448     err = socket_write_all(env, object, fd,
449             byteBuffer + off, len);
450     UNUSED(err);
451     // A return of -1 above means an exception is pending
452 
453     env->ReleaseByteArrayElements(buffer, byteBuffer, JNI_ABORT);
454 }
455 
socket_get_peer_credentials(JNIEnv * env,jobject object,jobject fileDescriptor)456 static jobject socket_get_peer_credentials(JNIEnv *env,
457         jobject object, jobject fileDescriptor)
458 {
459     int err;
460     int fd;
461 
462     if (fileDescriptor == NULL) {
463         jniThrowNullPointerException(env, NULL);
464         return NULL;
465     }
466 
467     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
468 
469     if (env->ExceptionCheck()) {
470         return NULL;
471     }
472 
473     struct ucred creds;
474 
475     memset(&creds, 0, sizeof(creds));
476     socklen_t szCreds = sizeof(creds);
477 
478     err = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
479 
480     if (err < 0) {
481         jniThrowIOException(env, errno);
482         return NULL;
483     }
484 
485     if (szCreds == 0) {
486         return NULL;
487     }
488 
489     return env->NewObject(class_Credentials, method_CredentialsInit,
490             creds.pid, creds.uid, creds.gid);
491 }
492 
493 /*
494  * JNI registration.
495  */
496 static const JNINativeMethod gMethods[] = {
497      /* name, signature, funcPtr */
498     {"connectLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V",
499                                                 (void*)socket_connect_local},
500     {"bindLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V", (void*)socket_bind_local},
501     {"read_native", "(Ljava/io/FileDescriptor;)I", (void*) socket_read},
502     {"readba_native", "([BIILjava/io/FileDescriptor;)I", (void*) socket_readba},
503     {"writeba_native", "([BIILjava/io/FileDescriptor;)V", (void*) socket_writeba},
504     {"write_native", "(ILjava/io/FileDescriptor;)V", (void*) socket_write},
505     {"getPeerCredentials_native",
506             "(Ljava/io/FileDescriptor;)Landroid/net/Credentials;",
507             (void*) socket_get_peer_credentials}
508 };
509 
register_android_net_LocalSocketImpl(JNIEnv * env)510 int register_android_net_LocalSocketImpl(JNIEnv *env)
511 {
512     jclass clazz;
513 
514     clazz = env->FindClass("android/net/LocalSocketImpl");
515 
516     if (clazz == NULL) {
517         goto error;
518     }
519 
520     field_inboundFileDescriptors = env->GetFieldID(clazz,
521             "inboundFileDescriptors", "[Ljava/io/FileDescriptor;");
522 
523     if (field_inboundFileDescriptors == NULL) {
524         goto error;
525     }
526 
527     field_outboundFileDescriptors = env->GetFieldID(clazz,
528             "outboundFileDescriptors", "[Ljava/io/FileDescriptor;");
529 
530     if (field_outboundFileDescriptors == NULL) {
531         goto error;
532     }
533 
534     class_Credentials = env->FindClass("android/net/Credentials");
535 
536     if (class_Credentials == NULL) {
537         goto error;
538     }
539 
540     class_Credentials = (jclass)env->NewGlobalRef(class_Credentials);
541 
542     class_FileDescriptor = env->FindClass("java/io/FileDescriptor");
543 
544     if (class_FileDescriptor == NULL) {
545         goto error;
546     }
547 
548     class_FileDescriptor = (jclass)env->NewGlobalRef(class_FileDescriptor);
549 
550     method_CredentialsInit
551             = env->GetMethodID(class_Credentials, "<init>", "(III)V");
552 
553     if (method_CredentialsInit == NULL) {
554         goto error;
555     }
556 
557     return jniRegisterNativeMethods(env,
558         "android/net/LocalSocketImpl", gMethods, NELEM(gMethods));
559 
560 error:
561     ALOGE("Error registering android.net.LocalSocketImpl");
562     return -1;
563 }
564 
565 };
566