1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns;
18 
19 import static java.util.concurrent.TimeUnit.MILLISECONDS;
20 import static java.util.concurrent.TimeUnit.SECONDS;
21 
22 import android.annotation.Nullable;
23 import android.os.SystemClock;
24 import android.text.TextUtils;
25 
26 import androidx.annotation.VisibleForTesting;
27 
28 import com.android.server.connectivity.mdns.util.MdnsUtils;
29 
30 import java.io.IOException;
31 import java.util.Arrays;
32 import java.util.Objects;
33 
34 /**
35  * Abstract base class for mDNS records. Stores the header fields and provides methods for reading
36  * the record from and writing it to a packet.
37  */
38 public abstract class MdnsRecord {
39     public static final int TYPE_A = 0x0001;
40     public static final int TYPE_AAAA = 0x001C;
41     public static final int TYPE_PTR = 0x000C;
42     public static final int TYPE_SRV = 0x0021;
43     public static final int TYPE_TXT = 0x0010;
44     public static final int TYPE_KEY = 0x0019;
45     public static final int TYPE_NSEC = 0x002f;
46     public static final int TYPE_ANY = 0x00ff;
47 
48     private static final int FLAG_CACHE_FLUSH = 0x8000;
49 
50     public static final long RECEIPT_TIME_NOT_SENT = 0L;
51     public static final int CLASS_ANY = 0x00ff;
52     /** Max label length as per RFC 1034/1035 */
53     public static final int MAX_LABEL_LENGTH = 63;
54 
55     /** Status indicating that the record is current. */
56     public static final int STATUS_OK = 0;
57     /** Status indicating that the record has expired (TTL reached 0). */
58     public static final int STATUS_EXPIRED = 1;
59     /** Status indicating that the record should be refreshed (Less than half of TTL remains.) */
60     public static final int STATUS_NEEDS_REFRESH = 2;
61 
62     protected final String[] name;
63     private final int type;
64     private final int cls;
65     private final long receiptTimeMillis;
66     private final long ttlMillis;
67     private Object key;
68 
69     /**
70      * Constructs a new record with the given name and type.
71      *
72      * @param reader The reader to read the record from.
73      * @param isQuestion Whether the record was included in the questions part of the message.
74      * @throws IOException If an error occurs while reading the packet.
75      */
MdnsRecord(String[] name, int type, MdnsPacketReader reader, boolean isQuestion)76     protected MdnsRecord(String[] name, int type, MdnsPacketReader reader, boolean isQuestion)
77             throws IOException {
78         this.name = name;
79         this.type = type;
80         cls = reader.readUInt16();
81         receiptTimeMillis = SystemClock.elapsedRealtime();
82 
83         if (isQuestion) {
84             // Questions do not have TTL or data
85             ttlMillis = 0L;
86         } else {
87             ttlMillis = SECONDS.toMillis(reader.readUInt32());
88             int dataLength = reader.readUInt16();
89 
90             reader.setLimit(dataLength);
91             readData(reader);
92             reader.clearLimit();
93         }
94     }
95 
96     /**
97      * Constructs a new record with the given name and type.
98      *
99      * @param reader The reader to read the record from.
100      * @throws IOException If an error occurs while reading the packet.
101      */
102     // call to readData(com.android.server.connectivity.mdns.MdnsPacketReader) not allowed on given
103     // receiver.
104     @SuppressWarnings("nullness:method.invocation.invalid")
MdnsRecord(String[] name, int type, MdnsPacketReader reader)105     protected MdnsRecord(String[] name, int type, MdnsPacketReader reader) throws IOException {
106         this(name, type, reader, false);
107     }
108 
109     /**
110      * Constructs a new record with the given properties.
111      */
MdnsRecord(String[] name, int type, int cls, long receiptTimeMillis, boolean cacheFlush, long ttlMillis)112     protected MdnsRecord(String[] name, int type, int cls, long receiptTimeMillis,
113             boolean cacheFlush, long ttlMillis) {
114         this.name = name;
115         this.type = type;
116         this.cls = cls | (cacheFlush ? FLAG_CACHE_FLUSH : 0);
117         this.receiptTimeMillis = receiptTimeMillis;
118         this.ttlMillis = ttlMillis;
119     }
120 
121     /**
122      * Converts an array of labels into their dot-separated string representation. This method
123      * should
124      * be used for logging purposes only.
125      */
labelsToString(String[] labels)126     public static String labelsToString(String[] labels) {
127         if (labels == null) {
128             return null;
129         }
130         return TextUtils.join(".", labels);
131     }
132 
133     /** Tests if |list1| is a suffix of |list2|. */
labelsAreSuffix(String[] list1, String[] list2)134     public static boolean labelsAreSuffix(String[] list1, String[] list2) {
135         int offset = list2.length - list1.length;
136 
137         if (offset < 1) {
138             return false;
139         }
140 
141         for (int i = 0; i < list1.length; ++i) {
142             if (!MdnsUtils.equalsIgnoreDnsCase(list1[i], list2[i + offset])) {
143                 return false;
144             }
145         }
146 
147         return true;
148     }
149 
150     /** Returns the record's receipt (creation) time. */
getReceiptTime()151     public final long getReceiptTime() {
152         return receiptTimeMillis;
153     }
154 
155     /** Returns the record's name. */
getName()156     public String[] getName() {
157         return name;
158     }
159 
160     /** Returns the record's original TTL, in milliseconds. */
getTtl()161     public final long getTtl() {
162         return ttlMillis;
163     }
164 
165     /** Returns the record's type. */
getType()166     public final int getType() {
167         return type;
168     }
169 
170     /** Return the record's class. */
getRecordClass()171     public final int getRecordClass() {
172         return cls & ~FLAG_CACHE_FLUSH;
173     }
174 
175     /** Return whether the cache flush flag is set. */
getCacheFlush()176     public final boolean getCacheFlush() {
177         return (cls & FLAG_CACHE_FLUSH) != 0;
178     }
179 
180     /**
181      * For questions, returns whether a unicast reply was requested.
182      *
183      * In practice this is identical to {@link #getCacheFlush()}, as the "cache flush" flag in
184      * replies is the same as "unicast reply requested" in questions.
185      */
isUnicastReplyRequested()186     public final boolean isUnicastReplyRequested() {
187         return (cls & MdnsConstants.QCLASS_UNICAST) != 0;
188     }
189 
190     /**
191      * Returns the record's remaining TTL.
192      *
193      * If the record was not sent yet (receipt time {@link #RECEIPT_TIME_NOT_SENT}), this is the
194      * original TTL of the record.
195      * @param now The current system time.
196      * @return The remaning TTL, in milliseconds.
197      */
getRemainingTTL(final long now)198     public long getRemainingTTL(final long now) {
199         if (receiptTimeMillis == RECEIPT_TIME_NOT_SENT) {
200             return ttlMillis;
201         }
202 
203         long age = now - receiptTimeMillis;
204         if (age > ttlMillis) {
205             return 0;
206         }
207 
208         return ttlMillis - age;
209     }
210 
211     /**
212      * Reads the record's payload from a packet.
213      *
214      * @param reader The reader to use.
215      * @throws IOException If an I/O error occurs.
216      */
readData(MdnsPacketReader reader)217     protected abstract void readData(MdnsPacketReader reader) throws IOException;
218 
219     /**
220      * Write the first fields of the record, which are common fields for questions and answers.
221      *
222      * @param writer The writer to use.
223      */
writeHeaderFields(MdnsPacketWriter writer)224     public final void writeHeaderFields(MdnsPacketWriter writer) throws IOException {
225         writer.writeLabels(name);
226         writer.writeUInt16(type);
227         writer.writeUInt16(cls);
228     }
229 
230     /**
231      * Writes the record to a packet.
232      *
233      * @param writer The writer to use.
234      * @param now    The current system time. This is used when writing the updated TTL.
235      */
236     @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
write(MdnsPacketWriter writer, long now)237     public final void write(MdnsPacketWriter writer, long now) throws IOException {
238         writeHeaderFields(writer);
239 
240         writer.writeUInt32(MILLISECONDS.toSeconds(getRemainingTTL(now)));
241 
242         int dataLengthPos = writer.getWritePosition();
243         writer.writeUInt16(0); // data length
244         int dataPos = writer.getWritePosition();
245 
246         writeData(writer);
247 
248         // Calculate amount of data written, and overwrite the data field earlier in the packet.
249         int endPos = writer.getWritePosition();
250         int dataLength = endPos - dataPos;
251         writer.rewind(dataLengthPos);
252         writer.writeUInt16(dataLength);
253         writer.unrewind();
254     }
255 
256     /**
257      * Writes the record's payload to a packet.
258      *
259      * @param writer The writer to use.
260      * @throws IOException If an I/O error occurs.
261      */
writeData(MdnsPacketWriter writer)262     protected abstract void writeData(MdnsPacketWriter writer) throws IOException;
263 
264     /** Gets the status of the record. */
getStatus(final long now)265     public int getStatus(final long now) {
266         if (receiptTimeMillis == RECEIPT_TIME_NOT_SENT) {
267             return STATUS_OK;
268         }
269         final long age = now - receiptTimeMillis;
270         if (age > ttlMillis) {
271             return STATUS_EXPIRED;
272         }
273         if (age > (ttlMillis / 2)) {
274             return STATUS_NEEDS_REFRESH;
275         }
276         return STATUS_OK;
277     }
278 
279     @Override
equals(@ullable Object other)280     public boolean equals(@Nullable Object other) {
281         if (!(other instanceof MdnsRecord)) {
282             return false;
283         }
284 
285         MdnsRecord otherRecord = (MdnsRecord) other;
286 
287         return MdnsUtils.equalsDnsLabelIgnoreDnsCase(name, otherRecord.name) && (type
288                 == otherRecord.type);
289     }
290 
291     @Override
hashCode()292     public int hashCode() {
293         return Objects.hash(Arrays.hashCode(MdnsUtils.toDnsLabelsLowerCase(name)), type);
294     }
295 
296     /**
297      * Returns an opaque object that uniquely identifies this record through a combination of its
298      * type
299      * and name. Suitable for use as a key in caches.
300      */
getKey()301     public final Object getKey() {
302         if (key == null) {
303             key = new Key(type, name);
304         }
305         return key;
306     }
307 
308     private static final class Key {
309         private final int recordType;
310         private final String[] recordName;
311 
Key(int recordType, String[] recordName)312         public Key(int recordType, String[] recordName) {
313             this.recordType = recordType;
314             this.recordName = MdnsUtils.toDnsLabelsLowerCase(recordName);
315         }
316 
317         @Override
equals(@ullable Object other)318         public boolean equals(@Nullable Object other) {
319             if (this == other) {
320                 return true;
321             }
322             if (!(other instanceof Key)) {
323                 return false;
324             }
325 
326             Key otherKey = (Key) other;
327 
328             return (recordType == otherKey.recordType) && Arrays.equals(recordName,
329                     otherKey.recordName);
330         }
331 
332         @Override
hashCode()333         public int hashCode() {
334             return (recordType * 31) + Arrays.hashCode(recordName);
335         }
336     }
337 }
338