1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns;
18 
19 import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
20 import com.android.server.connectivity.mdns.util.MdnsUtils;
21 
22 import java.io.IOException;
23 import java.net.DatagramPacket;
24 import java.net.SocketAddress;
25 import java.util.HashMap;
26 import java.util.Map;
27 
28 /** Simple encoder for mDNS packets. */
29 public class MdnsPacketWriter {
30     private static final int MDNS_POINTER_MASK = 0xC000;
31     private final byte[] data;
32     private final Map<Integer, String[]> labelDictionary = new HashMap<>();
33     private int pos = 0;
34     private int savedWritePos = -1;
35 
36     /**
37      * Constructs a writer for a new packet.
38      *
39      * @param maxSize The maximum size of a packet.
40      */
MdnsPacketWriter(int maxSize)41     public MdnsPacketWriter(int maxSize) {
42         if (maxSize <= 0) {
43             throw new IllegalArgumentException("invalid size");
44         }
45 
46         data = new byte[maxSize];
47     }
48 
49     /**
50      * Constructs a writer for a new packet.
51      *
52      * @param buffer The buffer to write to.
53      */
MdnsPacketWriter(byte[] buffer)54     public MdnsPacketWriter(byte[] buffer) {
55         data = buffer;
56     }
57 
58     /** Returns the current write position. */
getWritePosition()59     public int getWritePosition() {
60         return pos;
61     }
62 
63     /**
64      * Saves the current write position and then rewinds the write position by the given number of
65      * bytes. This is useful for updating length fields earlier in the packet. Rewinds cannot be
66      * nested.
67      *
68      * @param position The position to rewind to.
69      * @throws IOException If the count would go beyond the beginning of the packet, or if there is
70      *                     already a rewind in effect.
71      */
rewind(int position)72     public void rewind(int position) throws IOException {
73         if ((savedWritePos != -1) || (position > pos) || (position < 0)) {
74             throw new IOException("invalid rewind");
75         }
76 
77         savedWritePos = pos;
78         pos = position;
79     }
80 
81     /**
82      * Sets the current write position to what it was prior to the last rewind.
83      *
84      * @throws IOException If there was no rewind in effect.
85      */
unrewind()86     public void unrewind() throws IOException {
87         if (savedWritePos == -1) {
88             throw new IOException("no rewind is in effect");
89         }
90         pos = savedWritePos;
91         savedWritePos = -1;
92     }
93 
94     /** Clears any rewind state. */
clearRewind()95     public void clearRewind() {
96         savedWritePos = -1;
97     }
98 
99     /**
100      * Writes an unsigned 8-bit integer.
101      *
102      * @param value The value to write.
103      * @throws IOException If there is not enough space remaining in the packet.
104      */
writeUInt8(int value)105     public void writeUInt8(int value) throws IOException {
106         checkRemaining(1);
107         data[pos++] = (byte) (value & 0xFF);
108     }
109 
110     /**
111      * Writes an unsigned 16-bit integer.
112      *
113      * @param value The value to write.
114      * @throws IOException If there is not enough space remaining in the packet.
115      */
writeUInt16(int value)116     public void writeUInt16(int value) throws IOException {
117         checkRemaining(2);
118         data[pos++] = (byte) ((value >>> 8) & 0xFF);
119         data[pos++] = (byte) (value & 0xFF);
120     }
121 
122     /**
123      * Writes an unsigned 32-bit integer.
124      *
125      * @param value The value to write.
126      * @throws IOException If there is not enough space remaining in the packet.
127      */
writeUInt32(long value)128     public void writeUInt32(long value) throws IOException {
129         checkRemaining(4);
130         data[pos++] = (byte) ((value >>> 24) & 0xFF);
131         data[pos++] = (byte) ((value >>> 16) & 0xFF);
132         data[pos++] = (byte) ((value >>> 8) & 0xFF);
133         data[pos++] = (byte) (value & 0xFF);
134     }
135 
136     /**
137      * Writes a specific number of bytes.
138      *
139      * @param data The array to write.
140      * @throws IOException If there is not enough space remaining in the packet.
141      */
writeBytes(byte[] data)142     public void writeBytes(byte[] data) throws IOException {
143         checkRemaining(data.length);
144         System.arraycopy(data, 0, this.data, pos, data.length);
145         pos += data.length;
146     }
147 
148     /**
149      * Writes a string.
150      *
151      * @param value The string to write.
152      * @throws IOException If there is not enough space remaining in the packet.
153      */
writeString(String value)154     public void writeString(String value) throws IOException {
155         byte[] utf8 = value.getBytes(MdnsConstants.getUtf8Charset());
156         writeUInt8(utf8.length);
157         writeBytes(utf8);
158     }
159 
writeTextEntry(TextEntry textEntry)160     public void writeTextEntry(TextEntry textEntry) throws IOException {
161         byte[] bytes = textEntry.toBytes();
162         writeUInt8(bytes.length);
163         writeBytes(bytes);
164     }
165 
166     /**
167      * Writes a series of labels. Uses name compression.
168      *
169      * @param labels The labels to write.
170      * @throws IOException If there is not enough space remaining in the packet.
171      */
writeLabels(String[] labels)172     public void writeLabels(String[] labels) throws IOException {
173         // See section 4.1.4 of RFC 1035 (http://tools.ietf.org/html/rfc1035) for a description
174         // of the name compression method used here.
175 
176         int suffixLength = 0;
177         int suffixPointer = 0;
178 
179         for (Map.Entry<Integer, String[]> entry : labelDictionary.entrySet()) {
180             int existingOffset = entry.getKey();
181             String[] existingLabels = entry.getValue();
182 
183             if (MdnsUtils.equalsDnsLabelIgnoreDnsCase(existingLabels, labels)) {
184                 writePointer(existingOffset);
185                 return;
186             } else if (MdnsRecord.labelsAreSuffix(existingLabels, labels)) {
187                 // Keep track of the longest matching suffix so far.
188                 if (existingLabels.length > suffixLength) {
189                     suffixLength = existingLabels.length;
190                     suffixPointer = existingOffset;
191                 }
192             }
193         }
194 
195         final int[] offsets;
196         if (suffixLength > 0) {
197             offsets = writePartialLabelsNoCompression(labels, labels.length - suffixLength);
198             writePointer(suffixPointer);
199         } else {
200             offsets = writeLabelsNoCompression(labels);
201         }
202 
203         // Add entries to the label dictionary for each suffix of the label list, including
204         // the whole list itself.
205         // Do not replace the last suffixLength suffixes that already have dictionary entries.
206         for (int i = 0, len = labels.length; i < labels.length - suffixLength; ++i, --len) {
207             String[] value = new String[len];
208             System.arraycopy(labels, i, value, 0, len);
209             labelDictionary.put(offsets[i], value);
210         }
211     }
212 
writePartialLabelsNoCompression(String[] labels, int count)213     private int[] writePartialLabelsNoCompression(String[] labels, int count) throws IOException {
214         int[] offsets = new int[count];
215         for (int i = 0; i < count; ++i) {
216             offsets[i] = getWritePosition();
217             writeString(labels[i]);
218         }
219         return offsets;
220     }
221 
222     /**
223      * Write a series a labels, without using name compression.
224      *
225      * @return The offsets where each label was written to.
226      */
writeLabelsNoCompression(String[] labels)227     public int[] writeLabelsNoCompression(String[] labels) throws IOException {
228         final int[] offsets = writePartialLabelsNoCompression(labels, labels.length);
229         writeUInt8(0); // NUL terminator
230         return offsets;
231     }
232 
233     /** Returns the number of bytes that can still be written. */
getRemaining()234     public int getRemaining() {
235         return data.length - pos;
236     }
237 
238     // Writes a pointer to a label.
writePointer(int offset)239     private void writePointer(int offset) throws IOException {
240         writeUInt16(MDNS_POINTER_MASK | offset);
241     }
242 
243     // Checks if the remaining space in the packet is at least |count|.
checkRemaining(int count)244     private void checkRemaining(int count) throws IOException {
245         if (getRemaining() < count) {
246             throw new IOException();
247         }
248     }
249 
250     /** Builds and returns the packet. */
getPacket(SocketAddress destAddress)251     public DatagramPacket getPacket(SocketAddress destAddress) throws IOException {
252         return new DatagramPacket(data, pos, destAddress);
253     }
254 }
255