1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.cts;
18 
19 import static org.junit.Assert.assertArrayEquals;
20 
21 import android.content.Context;
22 import android.net.IpSecAlgorithm;
23 import android.net.IpSecManager;
24 import android.net.IpSecTransform;
25 import android.system.Os;
26 import android.system.OsConstants;
27 import android.test.AndroidTestCase;
28 import android.util.Log;
29 
30 import java.io.FileDescriptor;
31 import java.io.IOException;
32 import java.net.DatagramPacket;
33 import java.net.DatagramSocket;
34 import java.net.Inet4Address;
35 import java.net.Inet6Address;
36 import java.net.InetAddress;
37 import java.net.InetSocketAddress;
38 import java.net.ServerSocket;
39 import java.net.Socket;
40 import java.net.SocketException;
41 import java.util.Arrays;
42 import java.util.concurrent.atomic.AtomicInteger;
43 
44 public class IpSecBaseTest extends AndroidTestCase {
45 
46     private static final String TAG = IpSecBaseTest.class.getSimpleName();
47 
48     protected static final String IPV4_LOOPBACK = "127.0.0.1";
49     protected static final String IPV6_LOOPBACK = "::1";
50     protected static final String[] LOOPBACK_ADDRS = new String[] {IPV4_LOOPBACK, IPV6_LOOPBACK};
51     protected static final int[] DIRECTIONS =
52             new int[] {IpSecManager.DIRECTION_IN, IpSecManager.DIRECTION_OUT};
53 
54     protected static final byte[] TEST_DATA = "Best test data ever!".getBytes();
55     protected static final int DATA_BUFFER_LEN = 4096;
56     protected static final int SOCK_TIMEOUT = 500;
57 
58     private static final byte[] KEY_DATA = {
59         0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
60         0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
61         0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
62         0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
63         0x20, 0x21, 0x22, 0x23
64     };
65 
66     protected static final byte[] AUTH_KEY = getKey(256);
67     protected static final byte[] CRYPT_KEY = getKey(256);
68 
69     protected IpSecManager mISM;
70 
setUp()71     protected void setUp() throws Exception {
72         super.setUp();
73         mISM = (IpSecManager) getContext().getSystemService(Context.IPSEC_SERVICE);
74     }
75 
getKey(int bitLength)76     protected static byte[] getKey(int bitLength) {
77         return Arrays.copyOf(KEY_DATA, bitLength / 8);
78     }
79 
getDomain(InetAddress address)80     protected static int getDomain(InetAddress address) {
81         int domain;
82         if (address instanceof Inet6Address) {
83             domain = OsConstants.AF_INET6;
84         } else {
85             domain = OsConstants.AF_INET;
86         }
87         return domain;
88     }
89 
getPort(FileDescriptor sock)90     protected static int getPort(FileDescriptor sock) throws Exception {
91         return ((InetSocketAddress) Os.getsockname(sock)).getPort();
92     }
93 
94     public static interface GenericSocket extends AutoCloseable {
send(byte[] data)95         void send(byte[] data) throws Exception;
96 
receive()97         byte[] receive() throws Exception;
98 
getPort()99         int getPort() throws Exception;
100 
close()101         void close() throws Exception;
102 
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)103         void applyTransportModeTransform(
104                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception;
105 
removeTransportModeTransforms(IpSecManager ism)106         void removeTransportModeTransforms(IpSecManager ism) throws Exception;
107     }
108 
109     public static interface GenericTcpSocket extends GenericSocket {}
110 
111     public static interface GenericUdpSocket extends GenericSocket {
sendTo(byte[] data, InetAddress dstAddr, int port)112         void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception;
113     }
114 
115     public abstract static class NativeSocket implements GenericSocket {
116         public FileDescriptor mFd;
117 
NativeSocket(FileDescriptor fd)118         public NativeSocket(FileDescriptor fd) {
119             mFd = fd;
120         }
121 
122         @Override
send(byte[] data)123         public void send(byte[] data) throws Exception {
124             Os.write(mFd, data, 0, data.length);
125         }
126 
127         @Override
receive()128         public byte[] receive() throws Exception {
129             byte[] in = new byte[DATA_BUFFER_LEN];
130             AtomicInteger bytesRead = new AtomicInteger(-1);
131 
132             Thread readSockThread = new Thread(() -> {
133                 long startTime = System.currentTimeMillis();
134                 while (bytesRead.get() < 0 && System.currentTimeMillis() < startTime + SOCK_TIMEOUT) {
135                     try {
136                         bytesRead.set(Os.recvfrom(mFd, in, 0, DATA_BUFFER_LEN, 0, null));
137                     } catch (Exception e) {
138                         Log.e(TAG, "Error encountered reading from socket", e);
139                     }
140                 }
141             });
142 
143             readSockThread.start();
144             readSockThread.join(SOCK_TIMEOUT);
145 
146             if (bytesRead.get() < 0) {
147                 throw new IOException("No data received from socket");
148             }
149 
150             return Arrays.copyOfRange(in, 0, bytesRead.get());
151         }
152 
153         @Override
getPort()154         public int getPort() throws Exception {
155             return IpSecBaseTest.getPort(mFd);
156         }
157 
158         @Override
close()159         public void close() throws Exception {
160             Os.close(mFd);
161         }
162 
163         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)164         public void applyTransportModeTransform(
165                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
166             ism.applyTransportModeTransform(mFd, direction, transform);
167         }
168 
169         @Override
removeTransportModeTransforms(IpSecManager ism)170         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
171             ism.removeTransportModeTransforms(mFd);
172         }
173     }
174 
175     public static class NativeTcpSocket extends NativeSocket implements GenericTcpSocket {
NativeTcpSocket(FileDescriptor fd)176         public NativeTcpSocket(FileDescriptor fd) {
177             super(fd);
178         }
179     }
180 
181     public static class NativeUdpSocket extends NativeSocket implements GenericUdpSocket {
NativeUdpSocket(FileDescriptor fd)182         public NativeUdpSocket(FileDescriptor fd) {
183             super(fd);
184         }
185 
186         @Override
sendTo(byte[] data, InetAddress dstAddr, int port)187         public void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception {
188             Os.sendto(mFd, data, 0, data.length, 0, dstAddr, port);
189         }
190     }
191 
192     public static class JavaUdpSocket implements GenericUdpSocket {
193         public final DatagramSocket mSocket;
194 
JavaUdpSocket(InetAddress localAddr)195         public JavaUdpSocket(InetAddress localAddr) {
196             try {
197                 mSocket = new DatagramSocket(0, localAddr);
198                 mSocket.setSoTimeout(SOCK_TIMEOUT);
199             } catch (SocketException e) {
200                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
201                 // could easily end up in an endless wait.
202                 throw new RuntimeException(e);
203             }
204         }
205 
206         @Override
send(byte[] data)207         public void send(byte[] data) throws Exception {
208             mSocket.send(new DatagramPacket(data, data.length));
209         }
210 
211         @Override
sendTo(byte[] data, InetAddress dstAddr, int port)212         public void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception {
213             mSocket.send(new DatagramPacket(data, data.length, dstAddr, port));
214         }
215 
216         @Override
getPort()217         public int getPort() throws Exception {
218             return mSocket.getLocalPort();
219         }
220 
221         @Override
close()222         public void close() throws Exception {
223             mSocket.close();
224         }
225 
226         @Override
receive()227         public byte[] receive() throws Exception {
228             DatagramPacket data = new DatagramPacket(new byte[DATA_BUFFER_LEN], DATA_BUFFER_LEN);
229             mSocket.receive(data);
230             return Arrays.copyOfRange(data.getData(), 0, data.getLength());
231         }
232 
233         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)234         public void applyTransportModeTransform(
235                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
236             ism.applyTransportModeTransform(mSocket, direction, transform);
237         }
238 
239         @Override
removeTransportModeTransforms(IpSecManager ism)240         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
241             ism.removeTransportModeTransforms(mSocket);
242         }
243     }
244 
245     public static class JavaTcpSocket implements GenericTcpSocket {
246         public final Socket mSocket;
247 
JavaTcpSocket(Socket socket)248         public JavaTcpSocket(Socket socket) {
249             mSocket = socket;
250             try {
251                 mSocket.setSoTimeout(SOCK_TIMEOUT);
252             } catch (SocketException e) {
253                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
254                 // could easily end up in an endless wait.
255                 throw new RuntimeException(e);
256             }
257         }
258 
259         @Override
send(byte[] data)260         public void send(byte[] data) throws Exception {
261             mSocket.getOutputStream().write(data);
262         }
263 
264         @Override
receive()265         public byte[] receive() throws Exception {
266             byte[] in = new byte[DATA_BUFFER_LEN];
267             int bytesRead = mSocket.getInputStream().read(in);
268             return Arrays.copyOfRange(in, 0, bytesRead);
269         }
270 
271         @Override
getPort()272         public int getPort() throws Exception {
273             return mSocket.getLocalPort();
274         }
275 
276         @Override
close()277         public void close() throws Exception {
278             mSocket.close();
279         }
280 
281         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)282         public void applyTransportModeTransform(
283                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
284             ism.applyTransportModeTransform(mSocket, direction, transform);
285         }
286 
287         @Override
removeTransportModeTransforms(IpSecManager ism)288         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
289             ism.removeTransportModeTransforms(mSocket);
290         }
291     }
292 
293     public static class SocketPair<T> {
294         public final T mLeftSock;
295         public final T mRightSock;
296 
SocketPair(T leftSock, T rightSock)297         public SocketPair(T leftSock, T rightSock) {
298             mLeftSock = leftSock;
299             mRightSock = rightSock;
300         }
301     }
302 
applyTransformBidirectionally( IpSecManager ism, IpSecTransform transform, GenericSocket socket)303     protected static void applyTransformBidirectionally(
304             IpSecManager ism, IpSecTransform transform, GenericSocket socket) throws Exception {
305         for (int direction : DIRECTIONS) {
306             socket.applyTransportModeTransform(ism, direction, transform);
307         }
308     }
309 
getNativeUdpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)310     public static SocketPair<NativeUdpSocket> getNativeUdpSocketPair(
311             InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)
312             throws Exception {
313         int domain = getDomain(localAddr);
314 
315         NativeUdpSocket leftSock = new NativeUdpSocket(
316             Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP));
317         NativeUdpSocket rightSock = new NativeUdpSocket(
318             Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP));
319 
320         for (NativeUdpSocket sock : new NativeUdpSocket[] {leftSock, rightSock}) {
321             applyTransformBidirectionally(ism, transform, sock);
322             Os.bind(sock.mFd, localAddr, 0);
323         }
324 
325         if (connected) {
326             Os.connect(leftSock.mFd, localAddr, rightSock.getPort());
327             Os.connect(rightSock.mFd, localAddr, leftSock.getPort());
328         }
329 
330         return new SocketPair<>(leftSock, rightSock);
331     }
332 
getNativeTcpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform)333     public static SocketPair<NativeTcpSocket> getNativeTcpSocketPair(
334             InetAddress localAddr, IpSecManager ism, IpSecTransform transform) throws Exception {
335         int domain = getDomain(localAddr);
336 
337         NativeTcpSocket server = new NativeTcpSocket(
338                 Os.socket(domain, OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
339         NativeTcpSocket client = new NativeTcpSocket(
340                 Os.socket(domain, OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
341 
342         Os.bind(server.mFd, localAddr, 0);
343 
344         applyTransformBidirectionally(ism, transform, server);
345         applyTransformBidirectionally(ism, transform, client);
346 
347         Os.listen(server.mFd, 10);
348         Os.connect(client.mFd, localAddr, server.getPort());
349         NativeTcpSocket accepted = new NativeTcpSocket(Os.accept(server.mFd, null));
350 
351         applyTransformBidirectionally(ism, transform, accepted);
352         server.close();
353 
354         return new SocketPair<>(client, accepted);
355     }
356 
getJavaUdpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)357     public static SocketPair<JavaUdpSocket> getJavaUdpSocketPair(
358             InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)
359             throws Exception {
360         JavaUdpSocket leftSock = new JavaUdpSocket(localAddr);
361         JavaUdpSocket rightSock = new JavaUdpSocket(localAddr);
362 
363         applyTransformBidirectionally(ism, transform, leftSock);
364         applyTransformBidirectionally(ism, transform, rightSock);
365 
366         if (connected) {
367             leftSock.mSocket.connect(localAddr, rightSock.mSocket.getLocalPort());
368             rightSock.mSocket.connect(localAddr, leftSock.mSocket.getLocalPort());
369         }
370 
371         return new SocketPair<>(leftSock, rightSock);
372     }
373 
getJavaTcpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform)374     public static SocketPair<JavaTcpSocket> getJavaTcpSocketPair(
375             InetAddress localAddr, IpSecManager ism, IpSecTransform transform) throws Exception {
376         JavaTcpSocket clientSock = new JavaTcpSocket(new Socket());
377         ServerSocket serverSocket = new ServerSocket();
378         serverSocket.bind(new InetSocketAddress(localAddr, 0));
379 
380         // While technically the client socket does not need to be bound, the OpenJDK implementation
381         // of Socket only allocates an FD when bind() or connect() or other similar methods are
382         // called. So we call bind to force the FD creation, so that we can apply a transform to it
383         // prior to socket connect.
384         clientSock.mSocket.bind(new InetSocketAddress(localAddr, 0));
385 
386         // IpSecService doesn't support serverSockets at the moment; workaround using FD
387         FileDescriptor serverFd = serverSocket.getImpl().getFD$();
388 
389         applyTransformBidirectionally(ism, transform, new NativeTcpSocket(serverFd));
390         applyTransformBidirectionally(ism, transform, clientSock);
391 
392         clientSock.mSocket.connect(new InetSocketAddress(localAddr, serverSocket.getLocalPort()));
393         JavaTcpSocket acceptedSock = new JavaTcpSocket(serverSocket.accept());
394 
395         applyTransformBidirectionally(ism, transform, acceptedSock);
396         serverSocket.close();
397 
398         return new SocketPair<>(clientSock, acceptedSock);
399     }
400 
checkSocketPair(GenericSocket left, GenericSocket right)401     private void checkSocketPair(GenericSocket left, GenericSocket right) throws Exception {
402         left.send(TEST_DATA);
403         assertArrayEquals(TEST_DATA, right.receive());
404 
405         right.send(TEST_DATA);
406         assertArrayEquals(TEST_DATA, left.receive());
407 
408         left.close();
409         right.close();
410     }
411 
checkUnconnectedUdpSocketPair( GenericUdpSocket left, GenericUdpSocket right, InetAddress localAddr)412     private void checkUnconnectedUdpSocketPair(
413             GenericUdpSocket left, GenericUdpSocket right, InetAddress localAddr) throws Exception {
414         left.sendTo(TEST_DATA, localAddr, right.getPort());
415         assertArrayEquals(TEST_DATA, right.receive());
416 
417         right.sendTo(TEST_DATA, localAddr, left.getPort());
418         assertArrayEquals(TEST_DATA, left.receive());
419 
420         left.close();
421         right.close();
422     }
423 
buildIpSecTransform( Context mContext, IpSecManager.SecurityParameterIndex spi, IpSecManager.UdpEncapsulationSocket encapSocket, InetAddress remoteAddr)424     protected static IpSecTransform buildIpSecTransform(
425             Context mContext,
426             IpSecManager.SecurityParameterIndex spi,
427             IpSecManager.UdpEncapsulationSocket encapSocket,
428             InetAddress remoteAddr)
429             throws Exception {
430         String localAddr = (remoteAddr instanceof Inet4Address) ? IPV4_LOOPBACK : IPV6_LOOPBACK;
431         IpSecTransform.Builder builder =
432                 new IpSecTransform.Builder(mContext)
433                 .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
434                 .setAuthentication(
435                         new IpSecAlgorithm(
436                                 IpSecAlgorithm.AUTH_HMAC_SHA256,
437                                 AUTH_KEY,
438                                 AUTH_KEY.length * 4));
439 
440         if (encapSocket != null) {
441             builder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
442         }
443 
444         return builder.buildTransportModeTransform(InetAddress.getByName(localAddr), spi);
445     }
446 
buildDefaultTransform(InetAddress localAddr)447     private IpSecTransform buildDefaultTransform(InetAddress localAddr) throws Exception {
448         try (IpSecManager.SecurityParameterIndex spi =
449                 mISM.allocateSecurityParameterIndex(localAddr)) {
450             return buildIpSecTransform(mContext, spi, null, localAddr);
451         }
452     }
453 
testJavaTcpSocketPair()454     public void testJavaTcpSocketPair() throws Exception {
455         for (String addr : LOOPBACK_ADDRS) {
456             InetAddress local = InetAddress.getByName(addr);
457             try (IpSecTransform transform = buildDefaultTransform(local)) {
458                 SocketPair<JavaTcpSocket> sockets = getJavaTcpSocketPair(local, mISM, transform);
459                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
460             }
461         }
462     }
463 
testJavaUdpSocketPair()464     public void testJavaUdpSocketPair() throws Exception {
465         for (String addr : LOOPBACK_ADDRS) {
466             InetAddress local = InetAddress.getByName(addr);
467             try (IpSecTransform transform = buildDefaultTransform(local)) {
468                 SocketPair<JavaUdpSocket> sockets =
469                         getJavaUdpSocketPair(local, mISM, transform, true);
470                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
471             }
472         }
473     }
474 
testJavaUdpSocketPairUnconnected()475     public void testJavaUdpSocketPairUnconnected() throws Exception {
476         for (String addr : LOOPBACK_ADDRS) {
477             InetAddress local = InetAddress.getByName(addr);
478             try (IpSecTransform transform = buildDefaultTransform(local)) {
479                 SocketPair<JavaUdpSocket> sockets =
480                         getJavaUdpSocketPair(local, mISM, transform, false);
481                 checkUnconnectedUdpSocketPair(sockets.mLeftSock, sockets.mRightSock, local);
482             }
483         }
484     }
485 
testNativeTcpSocketPair()486     public void testNativeTcpSocketPair() throws Exception {
487         for (String addr : LOOPBACK_ADDRS) {
488             InetAddress local = InetAddress.getByName(addr);
489             try (IpSecTransform transform = buildDefaultTransform(local)) {
490                 SocketPair<NativeTcpSocket> sockets =
491                         getNativeTcpSocketPair(local, mISM, transform);
492                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
493             }
494         }
495     }
496 
testNativeUdpSocketPair()497     public void testNativeUdpSocketPair() throws Exception {
498         for (String addr : LOOPBACK_ADDRS) {
499             InetAddress local = InetAddress.getByName(addr);
500             try (IpSecTransform transform = buildDefaultTransform(local)) {
501                 SocketPair<NativeUdpSocket> sockets =
502                         getNativeUdpSocketPair(local, mISM, transform, true);
503                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
504             }
505         }
506     }
507 
testNativeUdpSocketPairUnconnected()508     public void testNativeUdpSocketPairUnconnected() throws Exception {
509         for (String addr : LOOPBACK_ADDRS) {
510             InetAddress local = InetAddress.getByName(addr);
511             try (IpSecTransform transform = buildDefaultTransform(local)) {
512                 SocketPair<NativeUdpSocket> sockets =
513                         getNativeUdpSocketPair(local, mISM, transform, false);
514                 checkUnconnectedUdpSocketPair(sockets.mLeftSock, sockets.mRightSock, local);
515             }
516         }
517     }
518 }
519