1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns;
18 
19 import android.annotation.NonNull;
20 import android.annotation.Nullable;
21 import android.util.SparseArray;
22 
23 import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
24 
25 import java.io.EOFException;
26 import java.io.IOException;
27 import java.net.DatagramPacket;
28 import java.util.ArrayList;
29 import java.util.List;
30 import java.util.Locale;
31 
32 /** Simple decoder for mDNS packets. */
33 public class MdnsPacketReader {
34     // The total length in bytes should be less than 255 bytes anyway (including labels and label
35     // length) per RFC9267, so limit the number of labels to 128 (each label is 2 bytes with the
36     // length).
37     // https://www.rfc-editor.org/rfc/rfc9267.html#name-label-and-name-length-valid
38     private static final int LABEL_COUNT_LIMIT = 128;
39     private final byte[] buf;
40     private final int count;
41     private final SparseArray<LabelEntry> labelDictionary;
42     private final MdnsFeatureFlags mMdnsFeatureFlags;
43     private int pos;
44     private int limit;
45 
46     /** Constructs a reader for the given packet. */
MdnsPacketReader(DatagramPacket packet)47     public MdnsPacketReader(DatagramPacket packet) {
48         this(packet.getData(), packet.getLength(), MdnsFeatureFlags.newBuilder().build());
49     }
50 
51     /** Constructs a reader for the given packet. */
MdnsPacketReader(byte[] buffer, int length, @NonNull MdnsFeatureFlags mdnsFeatureFlags)52     public MdnsPacketReader(byte[] buffer, int length, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
53         buf = buffer;
54         count = length;
55         pos = 0;
56         limit = -1;
57         labelDictionary = new SparseArray<>(16);
58         mMdnsFeatureFlags = mdnsFeatureFlags;
59     }
60 
61     /**
62      * Sets a temporary limit (from the current read position) for subsequent reads. Any attempt to
63      * read past this limit will result in an EOFException.
64      *
65      * @param limit The new limit.
66      * @throws IOException If there is insufficient data for the new limit.
67      */
setLimit(int limit)68     public void setLimit(int limit) throws IOException {
69         if (limit >= 0) {
70             if (pos + limit <= count) {
71                 this.limit = pos + limit;
72             } else {
73                 throw new IOException(
74                         String.format(
75                                 Locale.ROOT,
76                                 "attempt to set limit beyond available data: %d exceeds %d",
77                                 pos + limit,
78                                 count));
79             }
80         }
81     }
82 
83     /** Clears the limit set by {@link #setLimit}. */
clearLimit()84     public void clearLimit() {
85         limit = -1;
86     }
87 
88     /**
89      * Returns the number of bytes left to read, between the current read position and either the
90      * limit (if set) or the end of the packet.
91      */
getRemaining()92     public int getRemaining() {
93         return (limit >= 0 ? limit : count) - pos;
94     }
95 
96     /**
97      * Reads an unsigned 8-bit integer.
98      *
99      * @throws EOFException If there are not enough bytes remaining in the packet to satisfy the
100      *                      read.
101      */
readUInt8()102     public int readUInt8() throws EOFException {
103         checkRemaining(1);
104         byte val = buf[pos++];
105         return val & 0xFF;
106     }
107 
108     /**
109      * Reads an unsigned 16-bit integer.
110      *
111      * @throws EOFException If there are not enough bytes remaining in the packet to satisfy the
112      *                      read.
113      */
readUInt16()114     public int readUInt16() throws EOFException {
115         checkRemaining(2);
116         int val = (buf[pos++] & 0xFF) << 8;
117         val |= (buf[pos++]) & 0xFF;
118         return val;
119     }
120 
121     /**
122      * Reads an unsigned 32-bit integer.
123      *
124      * @throws EOFException If there are not enough bytes remaining in the packet to satisfy the
125      *                      read.
126      */
readUInt32()127     public long readUInt32() throws EOFException {
128         checkRemaining(4);
129         long val = (long) (buf[pos++] & 0xFF) << 24;
130         val |= (long) (buf[pos++] & 0xFF) << 16;
131         val |= (long) (buf[pos++] & 0xFF) << 8;
132         val |= buf[pos++] & 0xFF;
133         return val;
134     }
135 
136     /**
137      * Reads a sequence of labels and returns them as an array of strings. A sequence of labels is
138      * either a sequence of strings terminated by a NUL byte, a sequence of strings terminated by a
139      * pointer, or a pointer.
140      *
141      * @throws EOFException If there are not enough bytes remaining in the packet to satisfy the
142      *                      read.
143      * @throws IOException  If invalid data is read.
144      */
readLabels()145     public String[] readLabels() throws IOException {
146         List<String> result = new ArrayList<>(5);
147         LabelEntry previousEntry = null;
148         int tracingHops = 0;
149 
150         while (getRemaining() > 0) {
151             byte nextByte = peekByte();
152 
153             if (nextByte == 0) {
154                 // A NUL byte terminates a sequence of labels.
155                 skip(1);
156                 break;
157             }
158 
159             int currentOffset = pos;
160 
161             boolean isLabelPointer = (nextByte & 0xC0) == 0xC0;
162             if (isLabelPointer) {
163                 // A pointer terminates a sequence of labels. Store the pointer value in the
164                 // previous label entry.
165                 int labelOffset = ((readUInt8() & 0x3F) << 8) | (readUInt8() & 0xFF);
166                 if (previousEntry != null) {
167                     previousEntry.nextOffset = labelOffset;
168                 }
169 
170                 // Follow the chain of labels starting at this pointer, adding all of them onto the
171                 // result.
172                 while (labelOffset != 0) {
173                     if (mMdnsFeatureFlags.mIsLabelCountLimitEnabled
174                             && tracingHops > LABEL_COUNT_LIMIT) {
175                         throw new IOException("Invalid MDNS response packet: Too many labels.");
176                     }
177                     LabelEntry entry = labelDictionary.get(labelOffset);
178                     if (entry == null) {
179                         throw new IOException(
180                                 String.format(Locale.ROOT, "Invalid label pointer: %04X",
181                                         labelOffset));
182                     }
183                     result.add(entry.label);
184                     labelOffset = entry.nextOffset;
185                     tracingHops++;
186                 }
187                 break;
188             } else {
189                 // It's an ordinary label. Chain it onto the previous label entry (if any), and add
190                 // it onto the result.
191                 String val = readString();
192                 LabelEntry newEntry = new LabelEntry(val);
193                 labelDictionary.put(currentOffset, newEntry);
194 
195                 if (previousEntry != null) {
196                     previousEntry.nextOffset = currentOffset;
197                 }
198                 previousEntry = newEntry;
199                 result.add(val);
200             }
201         }
202 
203         return result.toArray(new String[result.size()]);
204     }
205 
206     /**
207      * Reads a length-prefixed string.
208      *
209      * @throws EOFException If there are not enough bytes remaining in the packet to satisfy the
210      *                      read.
211      */
readString()212     public String readString() throws EOFException {
213         int len = readUInt8();
214         checkRemaining(len);
215         String val = new String(buf, pos, len, MdnsConstants.getUtf8Charset());
216         pos += len;
217         return val;
218     }
219 
220     @Nullable
readTextEntry()221     public TextEntry readTextEntry() throws EOFException {
222         int len = readUInt8();
223         checkRemaining(len);
224         byte[] bytes = new byte[len];
225         System.arraycopy(buf, pos, bytes, 0, bytes.length);
226         pos += len;
227         return TextEntry.fromBytes(bytes);
228     }
229 
230     /**
231      * Reads a specific number of bytes.
232      *
233      * @param bytes The array to fill.
234      * @throws EOFException If there are not enough bytes remaining in the packet to satisfy the
235      *                      read.
236      */
readBytes(byte[] bytes)237     public void readBytes(byte[] bytes) throws EOFException {
238         checkRemaining(bytes.length);
239         System.arraycopy(buf, pos, bytes, 0, bytes.length);
240         pos += bytes.length;
241     }
242 
243     /**
244      * Skips over the given number of bytes.
245      *
246      * @param count The number of bytes to read and discard.
247      * @throws EOFException If there are not enough bytes remaining in the packet to satisfy the
248      *                      read.
249      */
skip(int count)250     public void skip(int count) throws EOFException {
251         checkRemaining(count);
252         pos += count;
253     }
254 
255     /**
256      * Peeks at and returns the next byte in the packet, without advancing the read position.
257      *
258      * @throws EOFException If there are not enough bytes remaining in the packet to satisfy the
259      *                      read.
260      */
peekByte()261     public byte peekByte() throws EOFException {
262         checkRemaining(1);
263         return buf[pos];
264     }
265 
266     /** Returns the current byte position of the reader for the data packet. */
getPosition()267     public int getPosition() {
268         return pos;
269     }
270 
271     // Checks if the number of remaining bytes to be read in the packet is at least |count|.
checkRemaining(int count)272     private void checkRemaining(int count) throws EOFException {
273         if (getRemaining() < count) {
274             throw new EOFException();
275         }
276     }
277 
278     private static class LabelEntry {
279         public final String label;
280         public int nextOffset = 0;
281 
LabelEntry(String label)282         public LabelEntry(String label) {
283             this.label = label;
284         }
285     }
286 }
287