1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.server.connectivity.mdns.util; 18 19 import static com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED; 20 21 import android.annotation.NonNull; 22 import android.annotation.Nullable; 23 import android.net.Network; 24 import android.os.Build; 25 import android.os.Handler; 26 import android.os.SystemClock; 27 import android.util.ArraySet; 28 import android.util.Pair; 29 30 import com.android.server.connectivity.mdns.MdnsConstants; 31 import com.android.server.connectivity.mdns.MdnsPacket; 32 import com.android.server.connectivity.mdns.MdnsPacketWriter; 33 import com.android.server.connectivity.mdns.MdnsRecord; 34 35 import java.io.IOException; 36 import java.net.DatagramPacket; 37 import java.net.InetSocketAddress; 38 import java.nio.ByteBuffer; 39 import java.nio.CharBuffer; 40 import java.nio.charset.Charset; 41 import java.nio.charset.CharsetEncoder; 42 import java.nio.charset.StandardCharsets; 43 import java.util.ArrayList; 44 import java.util.Arrays; 45 import java.util.Collections; 46 import java.util.HashSet; 47 import java.util.List; 48 import java.util.Set; 49 50 /** 51 * Mdns utility functions. 52 */ 53 public class MdnsUtils { 54 MdnsUtils()55 private MdnsUtils() { } 56 57 /** 58 * Convert the string to DNS case-insensitive lowercase 59 * 60 * Per rfc6762#page-46, accented characters are not defined to be automatically equivalent to 61 * their unaccented counterparts. So the "DNS lowercase" should be if character is A-Z then they 62 * transform into a-z. Otherwise, they are kept as-is. 63 */ toDnsLowerCase(@onNull String string)64 public static String toDnsLowerCase(@NonNull String string) { 65 final char[] outChars = new char[string.length()]; 66 for (int i = 0; i < string.length(); i++) { 67 outChars[i] = toDnsLowerCase(string.charAt(i)); 68 } 69 return new String(outChars); 70 } 71 72 /** 73 * Create a ArraySet or HashSet based on the sdk version. 74 */ newSet()75 public static <Type> Set<Type> newSet() { 76 if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { 77 return new ArraySet<>(); 78 } else { 79 return new HashSet<>(); 80 } 81 } 82 83 /** 84 * Convert the array of labels to DNS case-insensitive lowercase. 85 */ toDnsLabelsLowerCase(@onNull String[] labels)86 public static String[] toDnsLabelsLowerCase(@NonNull String[] labels) { 87 final String[] outStrings = new String[labels.length]; 88 for (int i = 0; i < labels.length; ++i) { 89 outStrings[i] = toDnsLowerCase(labels[i]); 90 } 91 return outStrings; 92 } 93 94 /** 95 * Compare two strings by DNS case-insensitive lowercase. 96 */ equalsIgnoreDnsCase(@ullable String a, @Nullable String b)97 public static boolean equalsIgnoreDnsCase(@Nullable String a, @Nullable String b) { 98 if (a == null || b == null) { 99 return a == null && b == null; 100 } 101 if (a.length() != b.length()) return false; 102 for (int i = 0; i < a.length(); i++) { 103 if (toDnsLowerCase(a.charAt(i)) != toDnsLowerCase(b.charAt(i))) { 104 return false; 105 } 106 } 107 return true; 108 } 109 110 /** 111 * Compare two set of DNS labels by DNS case-insensitive lowercase. 112 */ equalsDnsLabelIgnoreDnsCase(@onNull String[] a, @NonNull String[] b)113 public static boolean equalsDnsLabelIgnoreDnsCase(@NonNull String[] a, @NonNull String[] b) { 114 if (a == b) { 115 return true; 116 } 117 int length = a.length; 118 if (b.length != length) { 119 return false; 120 } 121 for (int i = 0; i < length; i++) { 122 if (!equalsIgnoreDnsCase(a[i], b[i])) { 123 return false; 124 } 125 } 126 return true; 127 } 128 129 /** 130 * Compare labels a equals b or a is suffix of b. 131 * 132 * @param a the type or subtype. 133 * @param b the base type 134 */ typeEqualsOrIsSubtype(@onNull String[] a, @NonNull String[] b)135 public static boolean typeEqualsOrIsSubtype(@NonNull String[] a, 136 @NonNull String[] b) { 137 return MdnsUtils.equalsDnsLabelIgnoreDnsCase(a, b) 138 || ((b.length == (a.length + 2)) 139 && MdnsUtils.equalsIgnoreDnsCase(b[1], MdnsConstants.SUBTYPE_LABEL) 140 && MdnsRecord.labelsAreSuffix(a, b)); 141 } 142 toDnsLowerCase(char a)143 private static char toDnsLowerCase(char a) { 144 return a >= 'A' && a <= 'Z' ? (char) (a + ('a' - 'A')) : a; 145 } 146 147 /*** Ensure that current running thread is same as given handler thread */ ensureRunningOnHandlerThread(@onNull Handler handler)148 public static void ensureRunningOnHandlerThread(@NonNull Handler handler) { 149 if (!isRunningOnHandlerThread(handler)) { 150 throw new IllegalStateException( 151 "Not running on Handler thread: " + Thread.currentThread().getName()); 152 } 153 } 154 155 /*** Check that current running thread is same as given handler thread */ isRunningOnHandlerThread(@onNull Handler handler)156 public static boolean isRunningOnHandlerThread(@NonNull Handler handler) { 157 if (handler.getLooper().getThread() == Thread.currentThread()) { 158 return true; 159 } 160 return false; 161 } 162 163 /*** Check whether the target network matches the current network */ isNetworkMatched(@ullable Network targetNetwork, @Nullable Network currentNetwork)164 public static boolean isNetworkMatched(@Nullable Network targetNetwork, 165 @Nullable Network currentNetwork) { 166 return targetNetwork == null || targetNetwork.equals(currentNetwork); 167 } 168 169 /*** Check whether the target network matches any of the current networks */ isAnyNetworkMatched(@ullable Network targetNetwork, Set<Network> currentNetworks)170 public static boolean isAnyNetworkMatched(@Nullable Network targetNetwork, 171 Set<Network> currentNetworks) { 172 if (targetNetwork == null) { 173 return !currentNetworks.isEmpty(); 174 } 175 return currentNetworks.contains(targetNetwork); 176 } 177 178 /** 179 * Truncate a service name to up to maxLength UTF-8 bytes. 180 */ truncateServiceName(@onNull String originalName, int maxLength)181 public static String truncateServiceName(@NonNull String originalName, int maxLength) { 182 // UTF-8 is at most 4 bytes per character; return early in the common case where 183 // the name can't possibly be over the limit given its string length. 184 if (originalName.length() <= maxLength / 4) return originalName; 185 186 final Charset utf8 = StandardCharsets.UTF_8; 187 final CharsetEncoder encoder = utf8.newEncoder(); 188 final ByteBuffer out = ByteBuffer.allocate(maxLength); 189 // encode will write as many characters as possible to the out buffer, and just 190 // return an overflow code if there were too many characters (no need to check the 191 // return code here, this method truncates the name on purpose). 192 encoder.encode(CharBuffer.wrap(originalName), out, true /* endOfInput */); 193 return new String(out.array(), 0, out.position(), utf8); 194 } 195 196 /** 197 * Write the mdns packet from given MdnsPacket. 198 */ writeMdnsPacket(@onNull MdnsPacketWriter writer, @NonNull MdnsPacket packet)199 public static void writeMdnsPacket(@NonNull MdnsPacketWriter writer, @NonNull MdnsPacket packet) 200 throws IOException { 201 writer.writeUInt16(packet.transactionId); // Transaction ID (advertisement: 0) 202 writer.writeUInt16(packet.flags); // Response, authoritative (rfc6762 18.4) 203 writer.writeUInt16(packet.questions.size()); // questions count 204 writer.writeUInt16(packet.answers.size()); // answers count 205 writer.writeUInt16(packet.authorityRecords.size()); // authority entries count 206 writer.writeUInt16(packet.additionalRecords.size()); // additional records count 207 208 for (MdnsRecord record : packet.questions) { 209 // Questions do not have TTL or data 210 record.writeHeaderFields(writer); 211 } 212 for (MdnsRecord record : packet.answers) { 213 record.write(writer, 0L); 214 } 215 for (MdnsRecord record : packet.authorityRecords) { 216 record.write(writer, 0L); 217 } 218 for (MdnsRecord record : packet.additionalRecords) { 219 record.write(writer, 0L); 220 } 221 } 222 223 /** 224 * Create a raw DNS packet. 225 */ createRawDnsPacket(@onNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet)226 public static byte[] createRawDnsPacket(@NonNull byte[] packetCreationBuffer, 227 @NonNull MdnsPacket packet) throws IOException { 228 // TODO: support packets over size (send in multiple packets with TC bit set) 229 final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer); 230 writeMdnsPacket(writer, packet); 231 232 final int len = writer.getWritePosition(); 233 return Arrays.copyOfRange(packetCreationBuffer, 0, len); 234 } 235 236 /** 237 * Writes the possible query content of an MdnsPacket into the data buffer. 238 * 239 * <p>This method is specifically for query packets. It writes the question and answer sections 240 * into the data buffer only. 241 * 242 * @param packetCreationBuffer The data buffer for the query content. 243 * @param packet The MdnsPacket to be written into the data buffer. 244 * @return A Pair containing: 245 * 1. The remaining MdnsPacket data that could not fit in the buffer. 246 * 2. The length of the data written to the buffer. 247 */ 248 @Nullable writePossibleMdnsPacket( @onNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet)249 private static Pair<MdnsPacket, Integer> writePossibleMdnsPacket( 250 @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet) throws IOException { 251 MdnsPacket remainingPacket; 252 final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer); 253 writer.writeUInt16(packet.transactionId); // Transaction ID 254 255 final int flagsPos = writer.getWritePosition(); 256 writer.writeUInt16(0); // Flags, written later 257 writer.writeUInt16(0); // questions count, written later 258 writer.writeUInt16(0); // answers count, written later 259 writer.writeUInt16(0); // authority entries count, empty session for query 260 writer.writeUInt16(0); // additional records count, empty session for query 261 262 int writtenQuestions = 0; 263 int writtenAnswers = 0; 264 int lastValidPos = writer.getWritePosition(); 265 try { 266 for (MdnsRecord record : packet.questions) { 267 // Questions do not have TTL or data 268 record.writeHeaderFields(writer); 269 writtenQuestions++; 270 lastValidPos = writer.getWritePosition(); 271 } 272 for (MdnsRecord record : packet.answers) { 273 record.write(writer, 0L); 274 writtenAnswers++; 275 lastValidPos = writer.getWritePosition(); 276 } 277 remainingPacket = null; 278 } catch (IOException e) { 279 // Went over the packet limit; truncate 280 if (writtenQuestions == 0 && writtenAnswers == 0) { 281 // No space to write even one record: just throw (as subclass of IOException) 282 throw e; 283 } 284 285 // Set the last valid position as the final position (not as a rewind) 286 writer.rewind(lastValidPos); 287 writer.clearRewind(); 288 289 remainingPacket = new MdnsPacket(packet.flags, 290 packet.questions.subList( 291 writtenQuestions, packet.questions.size()), 292 packet.answers.subList( 293 writtenAnswers, packet.answers.size()), 294 Collections.emptyList(), /* authorityRecords */ 295 Collections.emptyList() /* additionalRecords */); 296 } 297 298 final int len = writer.getWritePosition(); 299 writer.rewind(flagsPos); 300 writer.writeUInt16(packet.flags | (remainingPacket == null ? 0 : FLAG_TRUNCATED)); 301 writer.writeUInt16(writtenQuestions); 302 writer.writeUInt16(writtenAnswers); 303 writer.unrewind(); 304 305 return Pair.create(remainingPacket, len); 306 } 307 308 /** 309 * Create Datagram packets from given MdnsPacket and InetSocketAddress. 310 * 311 * <p> If the MdnsPacket is too large for a single DatagramPacket, it will be split into 312 * multiple DatagramPackets. 313 */ createQueryDatagramPackets( @onNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)314 public static List<DatagramPacket> createQueryDatagramPackets( 315 @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet, 316 @NonNull InetSocketAddress destination) throws IOException { 317 final List<DatagramPacket> datagramPackets = new ArrayList<>(); 318 MdnsPacket remainingPacket = packet; 319 while (remainingPacket != null) { 320 final Pair<MdnsPacket, Integer> result = 321 writePossibleMdnsPacket(packetCreationBuffer, remainingPacket); 322 remainingPacket = result.first; 323 final int len = result.second; 324 final byte[] outBuffer = Arrays.copyOfRange(packetCreationBuffer, 0, len); 325 datagramPackets.add(new DatagramPacket(outBuffer, 0, outBuffer.length, destination)); 326 } 327 return datagramPackets; 328 } 329 330 /** 331 * Checks if the MdnsRecord needs to be renewed or not. 332 * 333 * <p>As per RFC6762 7.1 no need to query if remaining TTL is more than half the original one, 334 * so send the queries if half the TTL has passed. 335 */ isRecordRenewalNeeded(@onNull MdnsRecord mdnsRecord, final long now)336 public static boolean isRecordRenewalNeeded(@NonNull MdnsRecord mdnsRecord, final long now) { 337 return mdnsRecord.getTtl() > 0 338 && mdnsRecord.getRemainingTTL(now) <= mdnsRecord.getTtl() / 2; 339 } 340 341 /** 342 * Creates a new full subtype name with given service type and subtype labels. 343 * 344 * For example, given ["_http", "_tcp"] and "_printer", this method returns a new String array 345 * of ["_printer", "_sub", "_http", "_tcp"]. 346 */ constructFullSubtype(String[] serviceType, String subtype)347 public static String[] constructFullSubtype(String[] serviceType, String subtype) { 348 String[] fullSubtype = new String[serviceType.length + 2]; 349 fullSubtype[0] = subtype; 350 fullSubtype[1] = MdnsConstants.SUBTYPE_LABEL; 351 System.arraycopy(serviceType, 0, fullSubtype, 2, serviceType.length); 352 return fullSubtype; 353 } 354 355 /** A wrapper class of {@link SystemClock} to be mocked in unit tests. */ 356 public static class Clock { 357 /** 358 * @see SystemClock#elapsedRealtime 359 */ elapsedRealtime()360 public long elapsedRealtime() { 361 return SystemClock.elapsedRealtime(); 362 } 363 } 364 }