/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.connectivity.mdns;

import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
import com.android.server.connectivity.mdns.util.MdnsUtils;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.SocketAddress;
import java.util.HashMap;
import java.util.Map;

/**
 * Simple encoder for mDNS packets.
 *
 * <p>Values are appended to a fixed-size byte array at a moving write position. Multi-byte
 * integers are written most-significant byte first. Label lists written through
 * {@link #writeLabels(String[])} are remembered, keyed by the offset they were written at, so
 * that later names can refer back to them with compression pointers.
 */
public class MdnsPacketWriter {
    /** The two high bits that mark a 16-bit value as a compression pointer. */
    private static final int MDNS_POINTER_MASK = 0xC000;
    /** Largest offset that fits in the 14 bits left over next to {@link #MDNS_POINTER_MASK}. */
    private static final int MAX_POINTER_OFFSET = 0x3FFF;
    /** Largest payload whose length can be expressed in the single length byte. */
    private static final int MAX_LENGTH_PREFIXED_SIZE = 0xFF;

    private final byte[] data;
    /** Maps the offset at which a label list was written to that label list. */
    private final Map<Integer, String[]> labelDictionary = new HashMap<>();
    private int pos = 0;
    /** Write position saved by {@link #rewind(int)}, or -1 when no rewind is in effect. */
    private int savedWritePos = -1;

    /**
     * Constructs a writer for a new packet, backed by a freshly allocated buffer.
     *
     * @param maxSize The maximum size of a packet.
     * @throws IllegalArgumentException If {@code maxSize} is not positive.
     */
    public MdnsPacketWriter(int maxSize) {
        if (maxSize <= 0) {
            throw new IllegalArgumentException("invalid size");
        }

        data = new byte[maxSize];
    }

    /**
     * Constructs a writer for a new packet.
     *
     * @param buffer The buffer to write to. It is used directly, not copied, and its length is
     *     the maximum packet size.
     */
    public MdnsPacketWriter(byte[] buffer) {
        data = buffer;
    }

    /** Returns the current write position. */
    public int getWritePosition() {
        return pos;
    }

    /**
     * Saves the current write position and then moves the write position back to the given
     * absolute position. This is useful for updating length fields earlier in the packet. Rewinds
     * cannot be nested.
     *
     * @param position The absolute position to rewind to; must be between 0 and the current write
     *     position, inclusive.
     * @throws IOException If the position is negative or beyond the current write position, or if
     *     there is already a rewind in effect.
     */
    public void rewind(int position) throws IOException {
        if ((savedWritePos != -1) || (position > pos) || (position < 0)) {
            throw new IOException("invalid rewind");
        }

        savedWritePos = pos;
        pos = position;
    }

    /**
     * Sets the current write position to what it was prior to the last rewind.
     *
     * @throws IOException If there was no rewind in effect.
     */
    public void unrewind() throws IOException {
        if (savedWritePos == -1) {
            throw new IOException("no rewind is in effect");
        }
        pos = savedWritePos;
        savedWritePos = -1;
    }

    /** Clears any rewind state. The current write position is left where it is. */
    public void clearRewind() {
        savedWritePos = -1;
    }

    /**
     * Writes an unsigned 8-bit integer. Only the low 8 bits of the value are written.
     *
     * @param value The value to write.
     * @throws IOException If there is not enough space remaining in the packet.
     */
    public void writeUInt8(int value) throws IOException {
        checkRemaining(1);
        data[pos++] = (byte) (value & 0xFF);
    }

    /**
     * Writes an unsigned 16-bit integer, most-significant byte first. Only the low 16 bits of the
     * value are written.
     *
     * @param value The value to write.
     * @throws IOException If there is not enough space remaining in the packet.
     */
    public void writeUInt16(int value) throws IOException {
        checkRemaining(2);
        data[pos++] = (byte) ((value >>> 8) & 0xFF);
        data[pos++] = (byte) (value & 0xFF);
    }

    /**
     * Writes an unsigned 32-bit integer, most-significant byte first. Only the low 32 bits of the
     * value are written.
     *
     * @param value The value to write.
     * @throws IOException If there is not enough space remaining in the packet.
     */
    public void writeUInt32(long value) throws IOException {
        checkRemaining(4);
        data[pos++] = (byte) ((value >>> 24) & 0xFF);
        data[pos++] = (byte) ((value >>> 16) & 0xFF);
        data[pos++] = (byte) ((value >>> 8) & 0xFF);
        data[pos++] = (byte) (value & 0xFF);
    }

    /**
     * Writes all bytes of the given array.
     *
     * @param data The array to write.
     * @throws IOException If there is not enough space remaining in the packet.
     */
    public void writeBytes(byte[] data) throws IOException {
        checkRemaining(data.length);
        System.arraycopy(data, 0, this.data, pos, data.length);
        pos += data.length;
    }

    /**
     * Writes a string as a single length byte followed by the string's bytes in the charset
     * returned by {@code MdnsConstants.getUtf8Charset()}.
     *
     * <p>This method is also used for DNS labels. RFC 1035 further limits a label to 63 bytes;
     * that stricter limit is not enforced here.
     *
     * @param value The string to write.
     * @throws IOException If the encoded string is longer than 255 bytes, or if there is not
     *     enough space remaining in the packet.
     */
    public void writeString(String value) throws IOException {
        writeLengthPrefixed(value.getBytes(MdnsConstants.getUtf8Charset()));
    }

    /**
     * Writes a text entry as a single length byte followed by the bytes returned by
     * {@code textEntry.toBytes()}.
     *
     * @param textEntry The text entry to write.
     * @throws IOException If the entry is longer than 255 bytes, or if there is not enough space
     *     remaining in the packet.
     */
    public void writeTextEntry(TextEntry textEntry) throws IOException {
        writeLengthPrefixed(textEntry.toBytes());
    }

    /**
     * Writes a series of labels. Uses name compression.
     *
     * @param labels The labels to write.
     * @throws IOException If there is not enough space remaining in the packet.
     */
    public void writeLabels(String[] labels) throws IOException {
        // See section 4.1.4 of RFC 1035 (http://tools.ietf.org/html/rfc1035) for a description
        // of the name compression method used here.

        int suffixLength = 0;
        int suffixPointer = 0;

        for (Map.Entry<Integer, String[]> entry : labelDictionary.entrySet()) {
            int existingOffset = entry.getKey();
            String[] existingLabels = entry.getValue();

            if (MdnsUtils.equalsDnsLabelIgnoreDnsCase(existingLabels, labels)) {
                // The whole name was written before: a pointer alone is enough.
                writePointer(existingOffset);
                return;
            } else if (MdnsRecord.labelsAreSuffix(existingLabels, labels)) {
                // Keep track of the longest matching suffix so far.
                if (existingLabels.length > suffixLength) {
                    suffixLength = existingLabels.length;
                    suffixPointer = existingOffset;
                }
            }
        }

        final int[] offsets;
        if (suffixLength > 0) {
            // Write only the leading labels, then point at the already-written suffix.
            offsets = writePartialLabelsNoCompression(labels, labels.length - suffixLength);
            writePointer(suffixPointer);
        } else {
            offsets = writeLabelsNoCompression(labels);
        }

        // Add entries to the label dictionary for each suffix of the label list, including
        // the whole list itself.
        // Do not replace the last suffixLength suffixes that already have dictionary entries.
        final int newEntryCount = labels.length - suffixLength;
        for (int i = 0; i < newEntryCount; ++i) {
            if (offsets[i] > MAX_POINTER_OFFSET) {
                // A pointer only has 14 bits for the offset, so this position can never be
                // referenced. Offsets only grow from here on, so stop recording.
                break;
            }
            final int len = labels.length - i;
            String[] value = new String[len];
            System.arraycopy(labels, i, value, 0, len);
            labelDictionary.put(offsets[i], value);
        }
    }

    /**
     * Writes the first {@code count} labels without compression and without a terminator.
     *
     * @return The offsets where each label was written to.
     */
    private int[] writePartialLabelsNoCompression(String[] labels, int count) throws IOException {
        int[] offsets = new int[count];
        for (int i = 0; i < count; ++i) {
            offsets[i] = getWritePosition();
            writeString(labels[i]);
        }
        return offsets;
    }

    /**
     * Writes a series of labels followed by a zero terminator, without using name compression.
     *
     * @param labels The labels to write.
     * @return The offsets where each label was written to.
     * @throws IOException If there is not enough space remaining in the packet.
     */
    public int[] writeLabelsNoCompression(String[] labels) throws IOException {
        final int[] offsets = writePartialLabelsNoCompression(labels, labels.length);
        writeUInt8(0); // NUL terminator
        return offsets;
    }

    /** Returns the number of bytes that can still be written. */
    public int getRemaining() {
        return data.length - pos;
    }

    // Writes a length byte followed by the payload, rejecting payloads whose length does not
    // fit in that single byte (writeUInt8 would otherwise silently keep only the low 8 bits).
    private void writeLengthPrefixed(byte[] bytes) throws IOException {
        if (bytes.length > MAX_LENGTH_PREFIXED_SIZE) {
            throw new IOException(
                    "length-prefixed data too long: " + bytes.length + " bytes, maximum is "
                            + MAX_LENGTH_PREFIXED_SIZE);
        }
        writeUInt8(bytes.length);
        writeBytes(bytes);
    }

    // Writes a pointer to a label. The offset must fit in the 14 bits below the pointer mask.
    private void writePointer(int offset) throws IOException {
        if (offset < 0 || offset > MAX_POINTER_OFFSET) {
            throw new IOException("label pointer offset out of range: " + offset);
        }
        writeUInt16(MDNS_POINTER_MASK | offset);
    }

    // Checks if the remaining space in the packet is at least |count|.
    private void checkRemaining(int count) throws IOException {
        final int remaining = getRemaining();
        if (remaining < count) {
            throw new IOException(
                    "not enough space in packet: need " + count + " bytes, " + remaining
                            + " remaining");
        }
    }

    /**
     * Builds and returns the packet. The packet wraps this writer's buffer directly (no copy) and
     * its length is the current write position.
     *
     * @param destAddress The address the packet is to be sent to.
     */
    public DatagramPacket getPacket(SocketAddress destAddress) throws IOException {
        return new DatagramPacket(data, pos, destAddress);
    }
}