1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns.util;
18 
19 import static com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED;
20 
21 import android.annotation.NonNull;
22 import android.annotation.Nullable;
23 import android.net.Network;
24 import android.os.Build;
25 import android.os.Handler;
26 import android.os.SystemClock;
27 import android.util.ArraySet;
28 import android.util.Pair;
29 
30 import com.android.server.connectivity.mdns.MdnsConstants;
31 import com.android.server.connectivity.mdns.MdnsPacket;
32 import com.android.server.connectivity.mdns.MdnsPacketWriter;
33 import com.android.server.connectivity.mdns.MdnsRecord;
34 
35 import java.io.IOException;
36 import java.net.DatagramPacket;
37 import java.net.InetSocketAddress;
38 import java.nio.ByteBuffer;
39 import java.nio.CharBuffer;
40 import java.nio.charset.Charset;
41 import java.nio.charset.CharsetEncoder;
42 import java.nio.charset.StandardCharsets;
43 import java.util.ArrayList;
44 import java.util.Arrays;
45 import java.util.Collections;
46 import java.util.HashSet;
47 import java.util.List;
48 import java.util.Set;
49 
50 /**
51  * Mdns utility functions.
52  */
53 public class MdnsUtils {
54 
MdnsUtils()55     private MdnsUtils() { }
56 
57     /**
58      * Convert the string to DNS case-insensitive lowercase
59      *
60      * Per rfc6762#page-46, accented characters are not defined to be automatically equivalent to
61      * their unaccented counterparts. So the "DNS lowercase" should be if character is A-Z then they
62      * transform into a-z. Otherwise, they are kept as-is.
63      */
toDnsLowerCase(@onNull String string)64     public static String toDnsLowerCase(@NonNull String string) {
65         final char[] outChars = new char[string.length()];
66         for (int i = 0; i < string.length(); i++) {
67             outChars[i] = toDnsLowerCase(string.charAt(i));
68         }
69         return new String(outChars);
70     }
71 
72     /**
73      * Create a ArraySet or HashSet based on the sdk version.
74      */
newSet()75     public static <Type> Set<Type> newSet() {
76         if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
77             return new ArraySet<>();
78         } else {
79             return new HashSet<>();
80         }
81     }
82 
83     /**
84      * Convert the array of labels to DNS case-insensitive lowercase.
85      */
toDnsLabelsLowerCase(@onNull String[] labels)86     public static String[] toDnsLabelsLowerCase(@NonNull String[] labels) {
87         final String[] outStrings = new String[labels.length];
88         for (int i = 0; i < labels.length; ++i) {
89             outStrings[i] = toDnsLowerCase(labels[i]);
90         }
91         return outStrings;
92     }
93 
94     /**
95      * Compare two strings by DNS case-insensitive lowercase.
96      */
equalsIgnoreDnsCase(@ullable String a, @Nullable String b)97     public static boolean equalsIgnoreDnsCase(@Nullable String a, @Nullable String b) {
98         if (a == null || b == null) {
99             return a == null && b == null;
100         }
101         if (a.length() != b.length()) return false;
102         for (int i = 0; i < a.length(); i++) {
103             if (toDnsLowerCase(a.charAt(i)) != toDnsLowerCase(b.charAt(i))) {
104                 return false;
105             }
106         }
107         return true;
108     }
109 
110     /**
111      * Compare two set of DNS labels by DNS case-insensitive lowercase.
112      */
equalsDnsLabelIgnoreDnsCase(@onNull String[] a, @NonNull String[] b)113     public static boolean equalsDnsLabelIgnoreDnsCase(@NonNull String[] a, @NonNull String[] b) {
114         if (a == b) {
115             return true;
116         }
117         int length = a.length;
118         if (b.length != length) {
119             return false;
120         }
121         for (int i = 0; i < length; i++) {
122             if (!equalsIgnoreDnsCase(a[i], b[i])) {
123                 return false;
124             }
125         }
126         return true;
127     }
128 
129     /**
130      * Compare labels a equals b or a is suffix of b.
131      *
132      * @param a the type or subtype.
133      * @param b the base type
134      */
typeEqualsOrIsSubtype(@onNull String[] a, @NonNull String[] b)135     public static boolean typeEqualsOrIsSubtype(@NonNull String[] a,
136             @NonNull String[] b) {
137         return MdnsUtils.equalsDnsLabelIgnoreDnsCase(a, b)
138                 || ((b.length == (a.length + 2))
139                 && MdnsUtils.equalsIgnoreDnsCase(b[1], MdnsConstants.SUBTYPE_LABEL)
140                 && MdnsRecord.labelsAreSuffix(a, b));
141     }
142 
toDnsLowerCase(char a)143     private static char toDnsLowerCase(char a) {
144         return a >= 'A' && a <= 'Z' ? (char) (a + ('a' - 'A')) : a;
145     }
146 
147     /*** Ensure that current running thread is same as given handler thread */
ensureRunningOnHandlerThread(@onNull Handler handler)148     public static void ensureRunningOnHandlerThread(@NonNull Handler handler) {
149         if (!isRunningOnHandlerThread(handler)) {
150             throw new IllegalStateException(
151                     "Not running on Handler thread: " + Thread.currentThread().getName());
152         }
153     }
154 
155     /*** Check that current running thread is same as given handler thread */
isRunningOnHandlerThread(@onNull Handler handler)156     public static boolean isRunningOnHandlerThread(@NonNull Handler handler) {
157         if (handler.getLooper().getThread() == Thread.currentThread()) {
158             return true;
159         }
160         return false;
161     }
162 
163     /*** Check whether the target network matches the current network */
isNetworkMatched(@ullable Network targetNetwork, @Nullable Network currentNetwork)164     public static boolean isNetworkMatched(@Nullable Network targetNetwork,
165             @Nullable Network currentNetwork) {
166         return targetNetwork == null || targetNetwork.equals(currentNetwork);
167     }
168 
169     /*** Check whether the target network matches any of the current networks */
isAnyNetworkMatched(@ullable Network targetNetwork, Set<Network> currentNetworks)170     public static boolean isAnyNetworkMatched(@Nullable Network targetNetwork,
171             Set<Network> currentNetworks) {
172         if (targetNetwork == null) {
173             return !currentNetworks.isEmpty();
174         }
175         return currentNetworks.contains(targetNetwork);
176     }
177 
178     /**
179      * Truncate a service name to up to maxLength UTF-8 bytes.
180      */
truncateServiceName(@onNull String originalName, int maxLength)181     public static String truncateServiceName(@NonNull String originalName, int maxLength) {
182         // UTF-8 is at most 4 bytes per character; return early in the common case where
183         // the name can't possibly be over the limit given its string length.
184         if (originalName.length() <= maxLength / 4) return originalName;
185 
186         final Charset utf8 = StandardCharsets.UTF_8;
187         final CharsetEncoder encoder = utf8.newEncoder();
188         final ByteBuffer out = ByteBuffer.allocate(maxLength);
189         // encode will write as many characters as possible to the out buffer, and just
190         // return an overflow code if there were too many characters (no need to check the
191         // return code here, this method truncates the name on purpose).
192         encoder.encode(CharBuffer.wrap(originalName), out, true /* endOfInput */);
193         return new String(out.array(), 0, out.position(), utf8);
194     }
195 
196     /**
197      * Write the mdns packet from given MdnsPacket.
198      */
writeMdnsPacket(@onNull MdnsPacketWriter writer, @NonNull MdnsPacket packet)199     public static void writeMdnsPacket(@NonNull MdnsPacketWriter writer, @NonNull MdnsPacket packet)
200             throws IOException {
201         writer.writeUInt16(packet.transactionId); // Transaction ID (advertisement: 0)
202         writer.writeUInt16(packet.flags); // Response, authoritative (rfc6762 18.4)
203         writer.writeUInt16(packet.questions.size()); // questions count
204         writer.writeUInt16(packet.answers.size()); // answers count
205         writer.writeUInt16(packet.authorityRecords.size()); // authority entries count
206         writer.writeUInt16(packet.additionalRecords.size()); // additional records count
207 
208         for (MdnsRecord record : packet.questions) {
209             // Questions do not have TTL or data
210             record.writeHeaderFields(writer);
211         }
212         for (MdnsRecord record : packet.answers) {
213             record.write(writer, 0L);
214         }
215         for (MdnsRecord record : packet.authorityRecords) {
216             record.write(writer, 0L);
217         }
218         for (MdnsRecord record : packet.additionalRecords) {
219             record.write(writer, 0L);
220         }
221     }
222 
223     /**
224      * Create a raw DNS packet.
225      */
createRawDnsPacket(@onNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet)226     public static byte[] createRawDnsPacket(@NonNull byte[] packetCreationBuffer,
227             @NonNull MdnsPacket packet) throws IOException {
228         // TODO: support packets over size (send in multiple packets with TC bit set)
229         final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer);
230         writeMdnsPacket(writer, packet);
231 
232         final int len = writer.getWritePosition();
233         return Arrays.copyOfRange(packetCreationBuffer, 0, len);
234     }
235 
236     /**
237      * Writes the possible query content of an MdnsPacket into the data buffer.
238      *
239      * <p>This method is specifically for query packets. It writes the question and answer sections
240      *    into the data buffer only.
241      *
242      * @param packetCreationBuffer The data buffer for the query content.
243      * @param packet The MdnsPacket to be written into the data buffer.
244      * @return A Pair containing:
245      *         1. The remaining MdnsPacket data that could not fit in the buffer.
246      *         2. The length of the data written to the buffer.
247      */
248     @Nullable
writePossibleMdnsPacket( @onNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet)249     private static Pair<MdnsPacket, Integer> writePossibleMdnsPacket(
250             @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet) throws IOException {
251         MdnsPacket remainingPacket;
252         final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer);
253         writer.writeUInt16(packet.transactionId); // Transaction ID
254 
255         final int flagsPos = writer.getWritePosition();
256         writer.writeUInt16(0); // Flags, written later
257         writer.writeUInt16(0); // questions count, written later
258         writer.writeUInt16(0); // answers count, written later
259         writer.writeUInt16(0); // authority entries count, empty session for query
260         writer.writeUInt16(0); // additional records count, empty session for query
261 
262         int writtenQuestions = 0;
263         int writtenAnswers = 0;
264         int lastValidPos = writer.getWritePosition();
265         try {
266             for (MdnsRecord record : packet.questions) {
267                 // Questions do not have TTL or data
268                 record.writeHeaderFields(writer);
269                 writtenQuestions++;
270                 lastValidPos = writer.getWritePosition();
271             }
272             for (MdnsRecord record : packet.answers) {
273                 record.write(writer, 0L);
274                 writtenAnswers++;
275                 lastValidPos = writer.getWritePosition();
276             }
277             remainingPacket = null;
278         } catch (IOException e) {
279             // Went over the packet limit; truncate
280             if (writtenQuestions == 0 && writtenAnswers == 0) {
281                 // No space to write even one record: just throw (as subclass of IOException)
282                 throw e;
283             }
284 
285             // Set the last valid position as the final position (not as a rewind)
286             writer.rewind(lastValidPos);
287             writer.clearRewind();
288 
289             remainingPacket = new MdnsPacket(packet.flags,
290                     packet.questions.subList(
291                             writtenQuestions, packet.questions.size()),
292                     packet.answers.subList(
293                             writtenAnswers, packet.answers.size()),
294                     Collections.emptyList(), /* authorityRecords */
295                     Collections.emptyList() /* additionalRecords */);
296         }
297 
298         final int len = writer.getWritePosition();
299         writer.rewind(flagsPos);
300         writer.writeUInt16(packet.flags | (remainingPacket == null ? 0 : FLAG_TRUNCATED));
301         writer.writeUInt16(writtenQuestions);
302         writer.writeUInt16(writtenAnswers);
303         writer.unrewind();
304 
305         return Pair.create(remainingPacket, len);
306     }
307 
308     /**
309      * Create Datagram packets from given MdnsPacket and InetSocketAddress.
310      *
311      * <p> If the MdnsPacket is too large for a single DatagramPacket, it will be split into
312      *     multiple DatagramPackets.
313      */
createQueryDatagramPackets( @onNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)314     public static List<DatagramPacket> createQueryDatagramPackets(
315             @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet,
316             @NonNull InetSocketAddress destination) throws IOException {
317         final List<DatagramPacket> datagramPackets = new ArrayList<>();
318         MdnsPacket remainingPacket = packet;
319         while (remainingPacket != null) {
320             final Pair<MdnsPacket, Integer> result =
321                     writePossibleMdnsPacket(packetCreationBuffer, remainingPacket);
322             remainingPacket = result.first;
323             final int len = result.second;
324             final byte[] outBuffer = Arrays.copyOfRange(packetCreationBuffer, 0, len);
325             datagramPackets.add(new DatagramPacket(outBuffer, 0, outBuffer.length, destination));
326         }
327         return datagramPackets;
328     }
329 
330     /**
331      * Checks if the MdnsRecord needs to be renewed or not.
332      *
333      * <p>As per RFC6762 7.1 no need to query if remaining TTL is more than half the original one,
334      * so send the queries if half the TTL has passed.
335      */
isRecordRenewalNeeded(@onNull MdnsRecord mdnsRecord, final long now)336     public static boolean isRecordRenewalNeeded(@NonNull MdnsRecord mdnsRecord, final long now) {
337         return mdnsRecord.getTtl() > 0
338                 && mdnsRecord.getRemainingTTL(now) <= mdnsRecord.getTtl() / 2;
339     }
340 
341     /**
342      * Creates a new full subtype name with given service type and subtype labels.
343      *
344      * For example, given ["_http", "_tcp"] and "_printer", this method returns a new String array
345      * of ["_printer", "_sub", "_http", "_tcp"].
346      */
constructFullSubtype(String[] serviceType, String subtype)347     public static String[] constructFullSubtype(String[] serviceType, String subtype) {
348         String[] fullSubtype = new String[serviceType.length + 2];
349         fullSubtype[0] = subtype;
350         fullSubtype[1] = MdnsConstants.SUBTYPE_LABEL;
351         System.arraycopy(serviceType, 0, fullSubtype, 2, serviceType.length);
352         return fullSubtype;
353     }
354 
355     /** A wrapper class of {@link SystemClock} to be mocked in unit tests. */
356     public static class Clock {
357         /**
358          * @see SystemClock#elapsedRealtime
359          */
elapsedRealtime()360         public long elapsedRealtime() {
361             return SystemClock.elapsedRealtime();
362         }
363     }
364 }