1 /*
2  * Copyright (C) 2008 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.cts;
18 
import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
import static com.android.testutils.MiscAsserts.assertThrows;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import android.net.Credentials;
import android.net.LocalServerSocket;
import android.net.LocalSocket;
import android.net.LocalSocketAddress;
import android.system.Os;
import android.system.OsConstants;
import android.system.StructTimeval;

import androidx.test.runner.AndroidJUnit4;

import com.android.testutils.DevSdkIgnoreRule;
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;

import java.io.FileDescriptor;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
59 
60 @RunWith(AndroidJUnit4.class)
61 public class LocalSocketTest {
62     private final static String ADDRESS_PREFIX = "com.android.net.LocalSocketTest";
63 
64     @Rule
65     public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
66 
67     @Test
testLocalConnections()68     public void testLocalConnections() throws IOException {
69         String address = ADDRESS_PREFIX + "_testLocalConnections";
70         // create client and server socket
71         LocalServerSocket localServerSocket = new LocalServerSocket(address);
72         LocalSocket clientSocket = new LocalSocket();
73 
74         // establish connection between client and server
75         LocalSocketAddress locSockAddr = new LocalSocketAddress(address);
76         assertFalse(clientSocket.isConnected());
77         clientSocket.connect(locSockAddr);
78         assertTrue(clientSocket.isConnected());
79 
80         LocalSocket serverSocket = localServerSocket.accept();
81         assertTrue(serverSocket.isConnected());
82         assertTrue(serverSocket.isBound());
83         assertThrows(IOException.class, () -> {
84             serverSocket.bind(localServerSocket.getLocalSocketAddress());
85         });
86         assertThrows(IOException.class, () -> {
87             serverSocket.connect(locSockAddr);
88         });
89 
90         Credentials credent = clientSocket.getPeerCredentials();
91         assertTrue(0 != credent.getPid());
92 
93         // send data from client to server
94         OutputStream clientOutStream = clientSocket.getOutputStream();
95         clientOutStream.write(12);
96         InputStream serverInStream = serverSocket.getInputStream();
97         assertEquals(12, serverInStream.read());
98 
99         //send data from server to client
100         OutputStream serverOutStream = serverSocket.getOutputStream();
101         serverOutStream.write(3);
102         InputStream clientInStream = clientSocket.getInputStream();
103         assertEquals(3, clientInStream.read());
104 
105         // Test sending and receiving file descriptors
106         clientSocket.setFileDescriptorsForSend(new FileDescriptor[]{FileDescriptor.in});
107         clientOutStream.write(32);
108         assertEquals(32, serverInStream.read());
109 
110         FileDescriptor[] out = serverSocket.getAncillaryFileDescriptors();
111         assertEquals(1, out.length);
112         FileDescriptor fd = clientSocket.getFileDescriptor();
113         assertTrue(fd.valid());
114 
115         //shutdown input stream of client
116         clientSocket.shutdownInput();
117         assertEquals(-1, clientInStream.read());
118 
119         //shutdown output stream of client
120         clientSocket.shutdownOutput();
121         assertThrows(IOException.class, () -> {
122             clientOutStream.write(10);
123         });
124 
125         //shutdown input stream of server
126         serverSocket.shutdownInput();
127         assertEquals(-1, serverInStream.read());
128 
129         //shutdown output stream of server
130         serverSocket.shutdownOutput();
131         assertThrows(IOException.class, () -> {
132             serverOutStream.write(10);
133         });
134 
135         //close client socket
136         clientSocket.close();
137         assertThrows(IOException.class, () -> {
138             clientInStream.read();
139         });
140 
141         //close server socket
142         serverSocket.close();
143         assertThrows(IOException.class, () -> {
144             serverInStream.read();
145         });
146     }
147 
148     @Test
testAccessors()149     public void testAccessors() throws IOException {
150         String address = ADDRESS_PREFIX + "_testAccessors";
151         LocalSocket socket = new LocalSocket();
152         LocalSocketAddress addr = new LocalSocketAddress(address);
153 
154         assertFalse(socket.isBound());
155         socket.bind(addr);
156         assertTrue(socket.isBound());
157         assertEquals(addr, socket.getLocalSocketAddress());
158 
159         String str = socket.toString();
160         assertTrue(str.contains("impl:android.net.LocalSocketImpl"));
161 
162         socket.setReceiveBufferSize(1999);
163         assertEquals(1999 << 1, socket.getReceiveBufferSize());
164 
165         socket.setSendBufferSize(3998);
166         assertEquals(3998 << 1, socket.getSendBufferSize());
167 
168         assertEquals(0, socket.getSoTimeout());
169         socket.setSoTimeout(1996);
170         assertTrue(socket.getSoTimeout() > 0);
171 
172         assertThrows(UnsupportedOperationException.class, () -> {
173             socket.getRemoteSocketAddress();
174         });
175 
176         assertThrows(UnsupportedOperationException.class, () -> {
177             socket.isClosed();
178         });
179 
180         assertThrows(UnsupportedOperationException.class, () -> {
181             socket.isInputShutdown();
182         });
183 
184         assertThrows(UnsupportedOperationException.class, () -> {
185             socket.isOutputShutdown();
186         });
187 
188         assertThrows(UnsupportedOperationException.class, () -> {
189             socket.connect(addr, 2005);
190         });
191 
192         socket.close();
193     }
194 
195     // http://b/31205169
196     @Test @IgnoreUpTo(SC_V2)  // Crashes on pre-T due to a JNI bug. See http://r.android.com/2096720
testSetSoTimeout_readTimeout()197     public void testSetSoTimeout_readTimeout() throws Exception {
198         String address = ADDRESS_PREFIX + "_testSetSoTimeout_readTimeout";
199 
200         try (LocalSocketPair socketPair = LocalSocketPair.createConnectedSocketPair(address)) {
201             final LocalSocket clientSocket = socketPair.clientSocket;
202 
203             // Set the timeout in millis.
204             int timeoutMillis = 1000;
205             clientSocket.setSoTimeout(timeoutMillis);
206 
207             // Avoid blocking the test run if timeout doesn't happen by using a separate thread.
208             Callable<Result> reader = () -> {
209                 try {
210                     clientSocket.getInputStream().read();
211                     return Result.noException("Did not block");
212                 } catch (IOException e) {
213                     return Result.exception(e);
214                 }
215             };
216             // Allow the configured timeout, plus some slop.
217             int allowedTime = timeoutMillis + 2000;
218             Result result = runInSeparateThread(allowedTime, reader);
219 
220             // Check the message was a timeout, it's all we have to go on.
221             String expectedMessage = Os.strerror(OsConstants.EAGAIN);
222             result.assertThrewIOException(expectedMessage);
223         }
224     }
225 
226     // http://b/31205169
227     @Test
testSetSoTimeout_writeTimeout()228     public void testSetSoTimeout_writeTimeout() throws Exception {
229         String address = ADDRESS_PREFIX + "_testSetSoTimeout_writeTimeout";
230 
231         try (LocalSocketPair socketPair = LocalSocketPair.createConnectedSocketPair(address)) {
232             final LocalSocket clientSocket = socketPair.clientSocket;
233 
234             // Set the timeout in millis.
235             int timeoutMillis = 1000;
236             clientSocket.setSoTimeout(timeoutMillis);
237 
238             // Set a small buffer size so we know we can flood it.
239             clientSocket.setSendBufferSize(100);
240             final int bufferSize = clientSocket.getSendBufferSize();
241 
242             // Avoid blocking the test run if timeout doesn't happen by using a separate thread.
243             Callable<Result> writer = () -> {
244                 try {
245                     byte[] toWrite = new byte[bufferSize * 2];
246                     clientSocket.getOutputStream().write(toWrite);
247                     return Result.noException("Did not block");
248                 } catch (IOException e) {
249                     return Result.exception(e);
250                 }
251             };
252             // Allow the configured timeout, plus some slop.
253             int allowedTime = timeoutMillis + 2000;
254 
255             Result result = runInSeparateThread(allowedTime, writer);
256 
257             // Check the message was a timeout, it's all we have to go on.
258             String expectedMessage = Os.strerror(OsConstants.EAGAIN);
259             result.assertThrewIOException(expectedMessage);
260         }
261     }
262 
263     @Test
testAvailable()264     public void testAvailable() throws Exception {
265         String address = ADDRESS_PREFIX + "_testAvailable";
266 
267         try (LocalSocketPair socketPair = LocalSocketPair.createConnectedSocketPair(address)) {
268             LocalSocket clientSocket = socketPair.clientSocket;
269             LocalSocket serverSocket = socketPair.serverSocket.accept();
270 
271             OutputStream clientOutputStream = clientSocket.getOutputStream();
272             InputStream serverInputStream = serverSocket.getInputStream();
273             assertEquals(0, serverInputStream.available());
274 
275             byte[] buffer = new byte[50];
276             clientOutputStream.write(buffer);
277             assertEquals(50, serverInputStream.available());
278 
279             InputStream clientInputStream = clientSocket.getInputStream();
280             OutputStream serverOutputStream = serverSocket.getOutputStream();
281             assertEquals(0, clientInputStream.available());
282             serverOutputStream.write(buffer);
283             assertEquals(50, serverInputStream.available());
284 
285             serverSocket.close();
286         }
287     }
288 
289     // http://b/34095140
290     @Test @IgnoreUpTo(SC_V2)
testLocalSocketCreatedFromFileDescriptor()291     public void testLocalSocketCreatedFromFileDescriptor() throws Exception {
292         String address = ADDRESS_PREFIX + "_testLocalSocketCreatedFromFileDescriptor";
293 
294         // Establish connection between a local client and server to get a valid client socket file
295         // descriptor.
296         try (LocalSocketPair socketPair = LocalSocketPair.createConnectedSocketPair(address)) {
297             // Extract the client FileDescriptor we can use.
298             FileDescriptor fileDescriptor = socketPair.clientSocket.getFileDescriptor();
299             assertTrue(fileDescriptor.valid());
300 
301             // Create the LocalSocket we want to test.
302             LocalSocket clientSocketCreatedFromFileDescriptor = new LocalSocket(fileDescriptor);
303             assertTrue(clientSocketCreatedFromFileDescriptor.isConnected());
304             assertTrue(clientSocketCreatedFromFileDescriptor.isBound());
305 
306             // Test the LocalSocket can be used for communication.
307             LocalSocket serverSocket = socketPair.serverSocket.accept();
308             OutputStream clientOutputStream =
309                     clientSocketCreatedFromFileDescriptor.getOutputStream();
310             InputStream serverInputStream = serverSocket.getInputStream();
311 
312             clientOutputStream.write(12);
313             assertEquals(12, serverInputStream.read());
314 
315             // Closing clientSocketCreatedFromFileDescriptor does not close the file descriptor.
316             clientSocketCreatedFromFileDescriptor.close();
317             assertTrue(fileDescriptor.valid());
318 
319             // .. while closing the LocalSocket that owned the file descriptor does.
320             socketPair.clientSocket.close();
321             assertFalse(fileDescriptor.valid());
322         }
323     }
324 
325     @Test
testFlush()326     public void testFlush() throws Exception {
327         String address = ADDRESS_PREFIX + "_testFlush";
328 
329         try (LocalSocketPair socketPair = LocalSocketPair.createConnectedSocketPair(address)) {
330             LocalSocket clientSocket = socketPair.clientSocket;
331             LocalSocket serverSocket = socketPair.serverSocket.accept();
332 
333             OutputStream clientOutputStream = clientSocket.getOutputStream();
334             InputStream serverInputStream = serverSocket.getInputStream();
335             testFlushWorks(clientOutputStream, serverInputStream);
336 
337             OutputStream serverOutputStream = serverSocket.getOutputStream();
338             InputStream clientInputStream = clientSocket.getInputStream();
339             testFlushWorks(serverOutputStream, clientInputStream);
340 
341             serverSocket.close();
342         }
343     }
344 
testFlushWorks(OutputStream outputStream, InputStream inputStream)345     private void testFlushWorks(OutputStream outputStream, InputStream inputStream)
346             throws Exception {
347         final int bytesToTransfer = 50;
348         StreamReader inputStreamReader = new StreamReader(inputStream, bytesToTransfer);
349 
350         byte[] buffer = new byte[bytesToTransfer];
351         outputStream.write(buffer);
352         assertEquals(bytesToTransfer, inputStream.available());
353 
354         // Start consuming the data.
355         inputStreamReader.start();
356 
357         // This doesn't actually flush any buffers, it just polls until the reader has read all the
358         // bytes.
359         outputStream.flush();
360 
361         inputStreamReader.waitForCompletion(5000);
362         inputStreamReader.assertBytesRead(bytesToTransfer);
363         assertEquals(0, inputStream.available());
364     }
365 
sendAndReceiveBytes(LocalSocket s1, LocalSocket s2)366     private void sendAndReceiveBytes(LocalSocket s1, LocalSocket s2) throws Exception {
367         final Random random = new Random();
368         final byte[] sendBytes = new byte[random.nextInt(511) + 1];  // Avoid 0-byte writes.
369         random.nextBytes(sendBytes);
370         final int numBytes = sendBytes.length;
371         final OutputStream os = s1.getOutputStream();
372         os.write(sendBytes);
373         os.flush();
374 
375         final InputStream is = s2.getInputStream();
376         final byte[] recvBytes = new byte[1024];
377         assertEquals(numBytes, is.read(recvBytes, 0, recvBytes.length));
378 
379         final byte[] received = Arrays.copyOfRange(recvBytes, 0, numBytes);
380         assertArrayEquals(received, sendBytes);
381     }
382 
383     /**
384      * Keeps track of the highest-numbered FD that is passed in.
385      */
386     private class MaxFdTracker{
387         private int mMax = -1;
388 
get()389         public int get() {
390             return mMax;
391         }
392 
noteFd(int fd)393         private void noteFd(int fd) {
394             mMax = Math.max(mMax, fd);
395         }
396 
noteFd(FileDescriptor fd)397         public void noteFd(FileDescriptor fd) {
398             noteFd(fd.getInt$());
399         }
400 
noteFd(LocalSocket s)401         public void noteFd(LocalSocket s) {
402             noteFd(s.getFileDescriptor().getInt$());
403         }
404     }
405 
406     @Test @IgnoreUpTo(SC_V2)
testCreateFromFd()407     public void testCreateFromFd() throws Exception {
408         String address = ADDRESS_PREFIX + "_testClosingConnectedSocket";
409         LocalServerSocket server = new LocalServerSocket(address);
410 
411         final int TIMEOUT_MS = 1000;
412 
413         final int NUM_ITERATIONS = 1000;
414         int firstFd = -1;
415         MaxFdTracker maxFd = new MaxFdTracker();
416 
417         for (int i = 0; i < NUM_ITERATIONS; i++) {
418             FileDescriptor fd = Os.socket(OsConstants.AF_UNIX, OsConstants.SOCK_STREAM, 0);
419             if (firstFd == -1) {
420                 firstFd = fd.getInt$();
421             } else  {
422                 maxFd.noteFd(fd);
423             }
424 
425             // Ensure the test doesn't hang by setting a reasonably short timeout.
426             // This seems easier than polling on non-blocking socket.
427             Os.setsockoptTimeval(fd, OsConstants.SOL_SOCKET, OsConstants.SO_RCVTIMEO,
428                     StructTimeval.fromMillis(TIMEOUT_MS));
429             Os.setsockoptTimeval(fd, OsConstants.SOL_SOCKET, OsConstants.SO_SNDTIMEO,
430                     StructTimeval.fromMillis(TIMEOUT_MS));
431 
432             final SocketAddress sockAddr = Os.getsockname(server.getFileDescriptor());
433             Os.connect(fd, sockAddr);
434 
435             LocalSocket accepted = server.accept();
436             accepted.setSoTimeout(TIMEOUT_MS);
437             maxFd.noteFd(accepted);
438 
439             LocalSocket ls = new LocalSocket(fd);
440             assertEquals(ls.getFileDescriptor().getInt$(), fd.getInt$());
441             maxFd.noteFd(ls);
442 
443             sendAndReceiveBytes(accepted, ls);
444             sendAndReceiveBytes(ls, accepted);
445 
446             accepted.close();
447             assertNull(accepted.getFileDescriptor());
448             Os.close(fd);
449         }
450         server.close();
451 
452         assertTrue("No FDs created!", firstFd != -1);
453         assertTrue("Only one FD created?", maxFd.get() != -1);
454         int fdsConsumed = maxFd.get() - firstFd;
455         assertTrue(
456                 "FD leak! Opened " + NUM_ITERATIONS + " sockets, FD int went up by " + fdsConsumed,
457             fdsConsumed < NUM_ITERATIONS / 2);
458     }
459 
460     @Test @IgnoreUpTo(SC_V2)
testCreateFromFd_notConnected()461     public void testCreateFromFd_notConnected() throws Exception {
462         FileDescriptor fd = Os.socket(OsConstants.AF_UNIX, OsConstants.SOCK_STREAM, 0);
463         assertThrows(IllegalArgumentException.class, () -> {
464             LocalSocket ls = new LocalSocket(fd);
465         });
466     }
467 
468     @Test @IgnoreUpTo(SC_V2)
testCreateFromFd_notSocket()469     public void testCreateFromFd_notSocket() throws Exception {
470         FileDescriptor fd = Os.open("/dev/null", 0 /* flags */, OsConstants.O_WRONLY);
471         assertThrows(IllegalArgumentException.class, () -> {
472             LocalSocket ls = new LocalSocket(fd);
473         });
474     }
475 
476     private static class StreamReader extends Thread {
477         private final InputStream is;
478         private final int expectedByteCount;
479         private final CountDownLatch completeLatch = new CountDownLatch(1);
480 
481         private volatile Exception exception;
482         private int bytesRead;
483 
StreamReader(InputStream is, int expectedByteCount)484         private StreamReader(InputStream is, int expectedByteCount) {
485             this.is = is;
486             this.expectedByteCount = expectedByteCount;
487         }
488 
489         @Override
run()490         public void run() {
491             try {
492                 byte[] buffer = new byte[10];
493                 int readCount;
494                 while ((readCount = is.read(buffer)) >= 0) {
495                     bytesRead += readCount;
496                     if (bytesRead >= expectedByteCount) {
497                         break;
498                     }
499                 }
500             } catch (IOException e) {
501                 exception = e;
502             } finally {
503                 completeLatch.countDown();
504             }
505         }
506 
waitForCompletion(long waitMillis)507         public void waitForCompletion(long waitMillis) throws Exception {
508             if (!completeLatch.await(waitMillis, TimeUnit.MILLISECONDS)) {
509                 fail("Timeout waiting for completion");
510             }
511             if (exception != null) {
512                 throw new Exception("Read failed", exception);
513             }
514         }
515 
assertBytesRead(int expected)516         public void assertBytesRead(int expected) {
517             assertEquals(expected, bytesRead);
518         }
519     }
520 
521     private static class Result {
522         private final String type;
523         private final Exception e;
524 
Result(String type, Exception e)525         private Result(String type, Exception e) {
526             this.type = type;
527             this.e = e;
528         }
529 
noException(String description)530         static Result noException(String description) {
531             return new Result(description, null);
532         }
533 
exception(Exception e)534         static Result exception(Exception e) {
535             return new Result(e.getClass().getName(), e);
536         }
537 
assertThrewIOException(String expectedMessage)538         void assertThrewIOException(String expectedMessage) {
539             assertEquals("Unexpected result type", IOException.class.getName(), type);
540             assertEquals("Unexpected exception message", expectedMessage, e.getMessage());
541         }
542     }
543 
runInSeparateThread(int allowedTime, final Callable<Result> callable)544     private static Result runInSeparateThread(int allowedTime, final Callable<Result> callable)
545             throws Exception {
546         ExecutorService service = Executors.newSingleThreadScheduledExecutor();
547         Future<Result> future = service.submit(callable);
548         Result result = future.get(allowedTime, TimeUnit.MILLISECONDS);
549         if (!future.isDone()) {
550             fail("Worker thread appears blocked");
551         }
552         return result;
553     }
554 
555     private static class LocalSocketPair implements AutoCloseable {
createConnectedSocketPair(String address)556         static LocalSocketPair createConnectedSocketPair(String address) throws Exception {
557             LocalServerSocket localServerSocket = new LocalServerSocket(address);
558             final LocalSocket clientSocket = new LocalSocket();
559 
560             // Establish connection between client and server
561             LocalSocketAddress locSockAddr = new LocalSocketAddress(address);
562             clientSocket.connect(locSockAddr);
563             assertTrue(clientSocket.isConnected());
564             return new LocalSocketPair(localServerSocket, clientSocket);
565         }
566 
567         final LocalServerSocket serverSocket;
568         final LocalSocket clientSocket;
569 
LocalSocketPair(LocalServerSocket serverSocket, LocalSocket clientSocket)570         LocalSocketPair(LocalServerSocket serverSocket, LocalSocket clientSocket) {
571             this.serverSocket = serverSocket;
572             this.clientSocket = clientSocket;
573         }
574 
close()575         public void close() throws Exception {
576             serverSocket.close();
577             clientSocket.close();
578         }
579     }
580 }
581