1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.net.thread.utils;
17 
import static android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK;
import static android.system.OsConstants.IPPROTO_ICMPV6;

import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_PIO;
import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;

import static com.google.common.util.concurrent.MoreExecutors.directExecutor;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;

import android.net.ConnectivityManager;
import android.net.InetAddresses;
import android.net.LinkAddress;
import android.net.Network;
import android.net.NetworkCapabilities;
import android.net.NetworkRequest;
import android.net.TestNetworkInterface;
import android.net.nsd.NsdManager;
import android.net.nsd.NsdServiceInfo;
import android.net.thread.ThreadNetworkController;
import android.os.Build;
import android.os.Handler;
import android.os.SystemClock;

import androidx.annotation.NonNull;
import androidx.test.core.app.ApplicationProvider;

import com.android.net.module.util.Struct;
import com.android.net.module.util.structs.Icmpv6Header;
import com.android.net.module.util.structs.Ipv6Header;
import com.android.net.module.util.structs.PrefixInformationOption;
import com.android.net.module.util.structs.RaHeader;
import com.android.testutils.HandlerUtils;
import com.android.testutils.TapPacketReader;

import com.google.common.util.concurrent.SettableFuture;

import java.io.FileDescriptor;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Predicate;
import java.util.function.Supplier;
75 
76 /** Static utility methods relating to Thread integration tests. */
77 public final class IntegrationTestUtils {
78     // The timeout of join() after restarting ot-daemon. The device needs to send 6 Link Request
79     // every 5 seconds, followed by 4 Parent Request every second. So this value needs to be 40
80     // seconds to be safe
81     public static final Duration RESTART_JOIN_TIMEOUT = Duration.ofSeconds(40);
82     public static final Duration JOIN_TIMEOUT = Duration.ofSeconds(30);
83     public static final Duration LEAVE_TIMEOUT = Duration.ofSeconds(2);
84     public static final Duration CALLBACK_TIMEOUT = Duration.ofSeconds(1);
85     public static final Duration SERVICE_DISCOVERY_TIMEOUT = Duration.ofSeconds(20);
86 
IntegrationTestUtils()87     private IntegrationTestUtils() {}
88 
89     /**
90      * Waits for the given {@link Supplier} to be true until given timeout.
91      *
92      * @param condition the condition to check
93      * @param timeout the time to wait for the condition before throwing
94      * @throws TimeoutException if the condition is still not met when the timeout expires
95      */
waitFor(Supplier<Boolean> condition, Duration timeout)96     public static void waitFor(Supplier<Boolean> condition, Duration timeout)
97             throws TimeoutException {
98         final long intervalMills = 500;
99         final long timeoutMills = timeout.toMillis();
100 
101         for (long i = 0; i < timeoutMills; i += intervalMills) {
102             if (condition.get()) {
103                 return;
104             }
105             SystemClock.sleep(intervalMills);
106         }
107         if (condition.get()) {
108             return;
109         }
110         throw new TimeoutException("The condition failed to become true in " + timeout);
111     }
112 
113     /**
114      * Creates a {@link TapPacketReader} given the {@link TestNetworkInterface} and {@link Handler}.
115      *
116      * @param testNetworkInterface the TUN interface of the test network
117      * @param handler the handler to process the packets
118      * @return the {@link TapPacketReader}
119      */
newPacketReader( TestNetworkInterface testNetworkInterface, Handler handler)120     public static TapPacketReader newPacketReader(
121             TestNetworkInterface testNetworkInterface, Handler handler) {
122         FileDescriptor fd = testNetworkInterface.getFileDescriptor().getFileDescriptor();
123         final TapPacketReader reader =
124                 new TapPacketReader(handler, fd, testNetworkInterface.getMtu());
125         handler.post(() -> reader.start());
126         HandlerUtils.waitForIdle(handler, 5000 /* timeout in milliseconds */);
127         return reader;
128     }
129 
130     /**
131      * Waits for the Thread module to enter any state of the given {@code deviceRoles}.
132      *
133      * @param controller the {@link ThreadNetworkController}
134      * @param deviceRoles the desired device roles. See also {@link
135      *     ThreadNetworkController.DeviceRole}
136      * @param timeout the time to wait for the expected state before throwing
137      * @return the {@link ThreadNetworkController.DeviceRole} after waiting
138      * @throws TimeoutException if the device hasn't become any of expected roles until the timeout
139      *     expires
140      */
waitForStateAnyOf( ThreadNetworkController controller, List<Integer> deviceRoles, Duration timeout)141     public static int waitForStateAnyOf(
142             ThreadNetworkController controller, List<Integer> deviceRoles, Duration timeout)
143             throws TimeoutException {
144         SettableFuture<Integer> future = SettableFuture.create();
145         ThreadNetworkController.StateCallback callback =
146                 newRole -> {
147                     if (deviceRoles.contains(newRole)) {
148                         future.set(newRole);
149                     }
150                 };
151         controller.registerStateCallback(directExecutor(), callback);
152         try {
153             return future.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
154         } catch (InterruptedException | ExecutionException e) {
155             throw new TimeoutException(
156                     String.format(
157                             "The device didn't become an expected role in %s: %s",
158                             timeout, e.getMessage()));
159         } finally {
160             controller.unregisterStateCallback(callback);
161         }
162     }
163 
164     /**
165      * Polls for a packet from a given {@link TapPacketReader} that satisfies the {@code filter}.
166      *
167      * @param packetReader a TUN packet reader
168      * @param filter the filter to be applied on the packet
169      * @return the first IPv6 packet that satisfies the {@code filter}. If it has waited for more
170      *     than 3000ms to read the next packet, the method will return null
171      */
pollForPacket(TapPacketReader packetReader, Predicate<byte[]> filter)172     public static byte[] pollForPacket(TapPacketReader packetReader, Predicate<byte[]> filter) {
173         byte[] packet;
174         while ((packet = packetReader.poll(3000 /* timeoutMs */, filter)) != null) {
175             return packet;
176         }
177         return null;
178     }
179 
180     /** Returns {@code true} if {@code packet} is an ICMPv6 packet of given {@code type}. */
isExpectedIcmpv6Packet(byte[] packet, int type)181     public static boolean isExpectedIcmpv6Packet(byte[] packet, int type) {
182         if (packet == null) {
183             return false;
184         }
185         ByteBuffer buf = ByteBuffer.wrap(packet);
186         try {
187             if (Struct.parse(Ipv6Header.class, buf).nextHeader != (byte) IPPROTO_ICMPV6) {
188                 return false;
189             }
190             return Struct.parse(Icmpv6Header.class, buf).type == (short) type;
191         } catch (IllegalArgumentException ignored) {
192             // It's fine that the passed in packet is malformed because it's could be sent
193             // by anybody.
194         }
195         return false;
196     }
197 
isFromIpv6Source(byte[] packet, Inet6Address src)198     public static boolean isFromIpv6Source(byte[] packet, Inet6Address src) {
199         if (packet == null) {
200             return false;
201         }
202         ByteBuffer buf = ByteBuffer.wrap(packet);
203         try {
204             return Struct.parse(Ipv6Header.class, buf).srcIp.equals(src);
205         } catch (IllegalArgumentException ignored) {
206             // It's fine that the passed in packet is malformed because it's could be sent
207             // by anybody.
208         }
209         return false;
210     }
211 
isToIpv6Destination(byte[] packet, Inet6Address dest)212     public static boolean isToIpv6Destination(byte[] packet, Inet6Address dest) {
213         if (packet == null) {
214             return false;
215         }
216         ByteBuffer buf = ByteBuffer.wrap(packet);
217         try {
218             return Struct.parse(Ipv6Header.class, buf).dstIp.equals(dest);
219         } catch (IllegalArgumentException ignored) {
220             // It's fine that the passed in packet is malformed because it's could be sent
221             // by anybody.
222         }
223         return false;
224     }
225 
226     /** Returns the Prefix Information Options (PIO) extracted from an ICMPv6 RA message. */
getRaPios(byte[] raMsg)227     public static List<PrefixInformationOption> getRaPios(byte[] raMsg) {
228         final ArrayList<PrefixInformationOption> pioList = new ArrayList<>();
229 
230         if (raMsg == null) {
231             return pioList;
232         }
233 
234         final ByteBuffer buf = ByteBuffer.wrap(raMsg);
235         final Ipv6Header ipv6Header = Struct.parse(Ipv6Header.class, buf);
236         if (ipv6Header.nextHeader != (byte) IPPROTO_ICMPV6) {
237             return pioList;
238         }
239 
240         final Icmpv6Header icmpv6Header = Struct.parse(Icmpv6Header.class, buf);
241         if (icmpv6Header.type != (short) ICMPV6_ROUTER_ADVERTISEMENT) {
242             return pioList;
243         }
244 
245         Struct.parse(RaHeader.class, buf);
246         while (buf.position() < raMsg.length) {
247             final int currentPos = buf.position();
248             final int type = Byte.toUnsignedInt(buf.get());
249             final int length = Byte.toUnsignedInt(buf.get());
250             if (type == ICMPV6_ND_OPTION_PIO) {
251                 final ByteBuffer pioBuf =
252                         ByteBuffer.wrap(
253                                 buf.array(),
254                                 currentPos,
255                                 Struct.getSize(PrefixInformationOption.class));
256                 final PrefixInformationOption pio =
257                         Struct.parse(PrefixInformationOption.class, pioBuf);
258                 pioList.add(pio);
259 
260                 // Move ByteBuffer position to the next option.
261                 buf.position(currentPos + Struct.getSize(PrefixInformationOption.class));
262             } else {
263                 // The length is in units of 8 octets.
264                 buf.position(currentPos + (length * 8));
265             }
266         }
267         return pioList;
268     }
269 
270     /**
271      * Sends a UDP message to a destination.
272      *
273      * @param dstAddress the IP address of the destination
274      * @param dstPort the port of the destination
275      * @param message the message in UDP payload
276      * @throws IOException if failed to send the message
277      */
sendUdpMessage(InetAddress dstAddress, int dstPort, String message)278     public static void sendUdpMessage(InetAddress dstAddress, int dstPort, String message)
279             throws IOException {
280         SocketAddress dstSockAddr = new InetSocketAddress(dstAddress, dstPort);
281 
282         try (DatagramSocket socket = new DatagramSocket()) {
283             socket.connect(dstSockAddr);
284 
285             byte[] msgBytes = message.getBytes();
286             DatagramPacket packet = new DatagramPacket(msgBytes, msgBytes.length);
287 
288             socket.send(packet);
289         }
290     }
291 
isInMulticastGroup(String interfaceName, Inet6Address address)292     public static boolean isInMulticastGroup(String interfaceName, Inet6Address address) {
293         final String cmd = "ip -6 maddr show dev " + interfaceName;
294         final String output = runShellCommandOrThrow(cmd);
295         final String addressStr = address.getHostAddress();
296         for (final String line : output.split("\\n")) {
297             if (line.contains(addressStr)) {
298                 return true;
299             }
300         }
301         return false;
302     }
303 
getIpv6LinkAddresses(String interfaceName)304     public static List<LinkAddress> getIpv6LinkAddresses(String interfaceName) {
305         List<LinkAddress> addresses = new ArrayList<>();
306         final String cmd = " ip -6 addr show dev " + interfaceName;
307         final String output = runShellCommandOrThrow(cmd);
308 
309         for (final String line : output.split("\\n")) {
310             if (line.contains("inet6")) {
311                 addresses.add(parseAddressLine(line));
312             }
313         }
314 
315         return addresses;
316     }
317 
318     /** Return the first discovered service of {@code serviceType}. */
discoverService(NsdManager nsdManager, String serviceType)319     public static NsdServiceInfo discoverService(NsdManager nsdManager, String serviceType)
320             throws Exception {
321         CompletableFuture<NsdServiceInfo> serviceInfoFuture = new CompletableFuture<>();
322         NsdManager.DiscoveryListener listener =
323                 new DefaultDiscoveryListener() {
324                     @Override
325                     public void onServiceFound(NsdServiceInfo serviceInfo) {
326                         serviceInfoFuture.complete(serviceInfo);
327                     }
328                 };
329         nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
330         try {
331             serviceInfoFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
332         } finally {
333             nsdManager.stopServiceDiscovery(listener);
334         }
335 
336         return serviceInfoFuture.get();
337     }
338 
339     /**
340      * Returns the {@link NsdServiceInfo} when a service instance of {@code serviceType} gets lost.
341      */
discoverForServiceLost( NsdManager nsdManager, String serviceType, CompletableFuture<NsdServiceInfo> serviceInfoFuture)342     public static NsdManager.DiscoveryListener discoverForServiceLost(
343             NsdManager nsdManager,
344             String serviceType,
345             CompletableFuture<NsdServiceInfo> serviceInfoFuture) {
346         NsdManager.DiscoveryListener listener =
347                 new DefaultDiscoveryListener() {
348                     @Override
349                     public void onServiceLost(NsdServiceInfo serviceInfo) {
350                         serviceInfoFuture.complete(serviceInfo);
351                     }
352                 };
353         nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
354         return listener;
355     }
356 
357     /** Resolves the service. */
resolveService(NsdManager nsdManager, NsdServiceInfo serviceInfo)358     public static NsdServiceInfo resolveService(NsdManager nsdManager, NsdServiceInfo serviceInfo)
359             throws Exception {
360         return resolveServiceUntil(nsdManager, serviceInfo, s -> true);
361     }
362 
363     /** Returns the first resolved service that satisfies the {@code predicate}. */
resolveServiceUntil( NsdManager nsdManager, NsdServiceInfo serviceInfo, Predicate<NsdServiceInfo> predicate)364     public static NsdServiceInfo resolveServiceUntil(
365             NsdManager nsdManager, NsdServiceInfo serviceInfo, Predicate<NsdServiceInfo> predicate)
366             throws Exception {
367         CompletableFuture<NsdServiceInfo> resolvedServiceInfoFuture = new CompletableFuture<>();
368         NsdManager.ServiceInfoCallback callback =
369                 new DefaultServiceInfoCallback() {
370                     @Override
371                     public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {
372                         if (predicate.test(serviceInfo)) {
373                             resolvedServiceInfoFuture.complete(serviceInfo);
374                         }
375                     }
376                 };
377         nsdManager.registerServiceInfoCallback(serviceInfo, directExecutor(), callback);
378         try {
379             return resolvedServiceInfoFuture.get(
380                     SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
381         } finally {
382             nsdManager.unregisterServiceInfoCallback(callback);
383         }
384     }
385 
getPrefixesFromNetData(String netData)386     public static String getPrefixesFromNetData(String netData) {
387         int startIdx = netData.indexOf("Prefixes:");
388         int endIdx = netData.indexOf("Routes:");
389         return netData.substring(startIdx, endIdx);
390     }
391 
getThreadNetwork(Duration timeout)392     public static Network getThreadNetwork(Duration timeout) throws Exception {
393         CompletableFuture<Network> networkFuture = new CompletableFuture<>();
394         ConnectivityManager cm =
395                 ApplicationProvider.getApplicationContext()
396                         .getSystemService(ConnectivityManager.class);
397         NetworkRequest.Builder networkRequestBuilder =
398                 new NetworkRequest.Builder().addTransportType(NetworkCapabilities.TRANSPORT_THREAD);
399         // Before V, we need to explicitly set `NET_CAPABILITY_LOCAL_NETWORK` capability to request
400         // a Thread network.
401         if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
402             networkRequestBuilder.addCapability(NET_CAPABILITY_LOCAL_NETWORK);
403         }
404         NetworkRequest networkRequest = networkRequestBuilder.build();
405         ConnectivityManager.NetworkCallback networkCallback =
406                 new ConnectivityManager.NetworkCallback() {
407                     @Override
408                     public void onAvailable(Network network) {
409                         networkFuture.complete(network);
410                     }
411                 };
412         cm.registerNetworkCallback(networkRequest, networkCallback);
413         return networkFuture.get(timeout.toSeconds(), SECONDS);
414     }
415 
416     private static class DefaultDiscoveryListener implements NsdManager.DiscoveryListener {
417         @Override
onStartDiscoveryFailed(String serviceType, int errorCode)418         public void onStartDiscoveryFailed(String serviceType, int errorCode) {}
419 
420         @Override
onStopDiscoveryFailed(String serviceType, int errorCode)421         public void onStopDiscoveryFailed(String serviceType, int errorCode) {}
422 
423         @Override
onDiscoveryStarted(String serviceType)424         public void onDiscoveryStarted(String serviceType) {}
425 
426         @Override
onDiscoveryStopped(String serviceType)427         public void onDiscoveryStopped(String serviceType) {}
428 
429         @Override
onServiceFound(NsdServiceInfo serviceInfo)430         public void onServiceFound(NsdServiceInfo serviceInfo) {}
431 
432         @Override
onServiceLost(NsdServiceInfo serviceInfo)433         public void onServiceLost(NsdServiceInfo serviceInfo) {}
434     }
435 
436     private static class DefaultServiceInfoCallback implements NsdManager.ServiceInfoCallback {
437         @Override
onServiceInfoCallbackRegistrationFailed(int errorCode)438         public void onServiceInfoCallbackRegistrationFailed(int errorCode) {}
439 
440         @Override
onServiceUpdated(@onNull NsdServiceInfo serviceInfo)441         public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {}
442 
443         @Override
onServiceLost()444         public void onServiceLost() {}
445 
446         @Override
onServiceInfoCallbackUnregistered()447         public void onServiceInfoCallbackUnregistered() {}
448     }
449 
450     /**
451      * Parses a line of output from "ip -6 addr show" into a {@link LinkAddress}.
452      *
453      * <p>Example line: "inet6 2001:db8:1:1::1/64 scope global deprecated"
454      */
parseAddressLine(String line)455     private static LinkAddress parseAddressLine(String line) {
456         String[] parts = line.trim().split("\\s+");
457         String addressString = parts[1];
458         String[] pieces = addressString.split("/", 2);
459         int prefixLength = Integer.parseInt(pieces[1]);
460         final InetAddress address = InetAddresses.parseNumericAddress(pieces[0]);
461         long deprecationTimeMillis =
462                 line.contains("deprecated")
463                         ? SystemClock.elapsedRealtime()
464                         : LinkAddress.LIFETIME_PERMANENT;
465 
466         return new LinkAddress(
467                 address,
468                 prefixLength,
469                 0 /* flags */,
470                 0 /* scope */,
471                 deprecationTimeMillis,
472                 LinkAddress.LIFETIME_PERMANENT /* expirationTime */);
473     }
474 }
475