1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.apf;
18 
import static android.net.apf.BaseApfGenerator.MemorySlot;
import static android.net.apf.BaseApfGenerator.Register.R0;
import static android.net.apf.BaseApfGenerator.Register.R1;

import static com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN;
import static com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN;

import android.annotation.NonNull;

import java.nio.charset.StandardCharsets;
import java.util.Objects;
27 
28 /**
29  * Utility class that generates generating APF filters for DNS packets.
30  */
31 public class DnsUtils {
32 
33     /** Length of the DNS header. */
34     private static final int DNS_HEADER_LEN = 12;
35     /** Offset of the qdcount field within the DNS header. */
36     private static final int DNS_QDCOUNT_OFFSET = 4;
37 
38     // Static labels
39     private static final String LABEL_START_MATCH = "start_match";
40     private static final String LABEL_PARSE_DNS_LABEL = "parse_dns_label";
41     private static final String LABEL_FIND_NEXT_DNS_QUESTION = "find_next_dns_question";
42 
43     // Length of the pointers used by compressed names.
44     private static final int LABEL_SIZE = Byte.BYTES;
45     private static final int POINTER_SIZE = Short.BYTES;
46     private static final int QUESTION_HEADER_SIZE = Short.BYTES + Short.BYTES;
47     private static final int LABEL_AND_QUESTION_HEADER_SIZE = LABEL_SIZE + QUESTION_HEADER_SIZE;
48     private static final int POINTER_AND_QUESTION_HEADER_SIZE = POINTER_SIZE + QUESTION_HEADER_SIZE;
49 
50     /** Memory slot that stores the offset within the packet of the DNS header. */
51     private static final MemorySlot SLOT_DNS_HEADER_OFFSET = MemorySlot.SLOT_1;
52     /** Memory slot that stores the current parsing offset. */
53     private static final MemorySlot SLOT_CURRENT_PARSE_OFFSET = MemorySlot.SLOT_2;
54     /**
55      * Memory slot that stores the offset after the current question, if the code is currently
56      * parsing a pointer, or 0 if it is not.
57      */
58     private static final MemorySlot SLOT_AFTER_POINTER_OFFSET = MemorySlot.SLOT_3;
59     /**
60      * Contains qdcount remaining, as a negative number. For example, will be -1 when starting to
61      * parse a DNS packet with one question in it. It's stored as a negative number because adding 1
62      * is much easier than subtracting 1 (which can't be done just by adding -1, because that just
63      * adds 254).
64      */
65     private static final MemorySlot SLOT_NEGATIVE_QDCOUNT_REMAINING = MemorySlot.SLOT_4;
66     /** Memory slot used by the jump table. */
67     private static final MemorySlot SLOT_RETURN_VALUE_INDEX = MemorySlot.SLOT_5;
68 
69     /**
70      * APF function: parse_dns_label
71      *
72      * Parses a label potentially containing a pointer, and calculates the label length and the
73      * offset of the label data.
74      *
75      * Inputs:
76      * - m[SLOT_DNS_HEADER_OFFSET]: offset of DNS header
77      * - m[SLOT_CURRENT_PARSE_OFFSET]: current parsing offset
78      * - m[SLOT_AFTER_POINTER_OFFSET]: offset after the question (e.g., offset of the next question,
79      *        or offset of the answer section) if a pointer is being chased, 0 otherwise
80      * - m[SLOT_RETURN_VALUE_INDEX]: index into return jump table
81      *
82      * Outputs:
83      * - R1: label length
84      * - m[SLOT_CURRENT_PARSE_OFFSET]: offset of label text
85      */
genParseDnsLabel(ApfV4Generator gen, JumpTable jumpTable)86     private static void genParseDnsLabel(ApfV4Generator gen, JumpTable jumpTable) throws Exception {
87         final String labelParseDnsLabelReal = "parse_dns_label_real";
88         final String labelPointerOffsetStored = "pointer_offset_stored";
89 
90         /**
91          * :parse_dns_label
92          * // Load parsing offset.
93          * LDM R1, 2                        // R1 = parsing offset. (All indexed loads use R1.)
94          */
95         gen.defineLabel(LABEL_PARSE_DNS_LABEL);
96         gen.addLoadFromMemory(R1, SLOT_CURRENT_PARSE_OFFSET);
97 
98 
99         /**
100          * // Check that we’re in the DNS packet, i.e., that R1 >= m[SLOT_DNS_HEADER_OFFSET].
101          * LDM R0, 1                        // R0 = DNS header offset
102          * JGT R0, R1, DROP                 // Bad pointer. Drop.
103          */
104         gen.addLoadFromMemory(R0, SLOT_DNS_HEADER_OFFSET);
105         gen.addJumpIfR0GreaterThanR1(ApfV4Generator.DROP_LABEL);
106 
107         /**
108          * // Now parse the label.
109          * LDBX R0, [R1]                    // R0 = label length, R1 = parsing offset
110          * AND R0, 0xc0                     // Is this a pointer?
111          *
112          * JEQ R0, 0, :parse_dns_label_real
113          */
114         gen.addLoad8Indexed(R0, 0);
115         gen.addAnd(0xc0);
116         gen.addJumpIfR0Equals(0, labelParseDnsLabelReal);
117 
118 
119         /**
120          * // If we’re not already chasing a pointer, store offset after pointer into
121          * // m[SLOT_AFTER_POINTER_OFFSET].
122          * LDM R0, 3                        // R0 = previous offset after pointer
123          * JNE 0, :pointer_offset_stored
124          * MOV R0, R1                       // R0 = R1
125          * ADD R0, 6                        // R0 = offset after pointer and record
126          * STM R0, 3                        // Store offset after pointer
127          */
128         gen.addLoadFromMemory(R0, SLOT_AFTER_POINTER_OFFSET);
129         gen.addJumpIfR0NotEquals(0, labelPointerOffsetStored);
130         gen.addMove(R0);
131         gen.addAdd(POINTER_AND_QUESTION_HEADER_SIZE);
132         gen.addStoreToMemory(SLOT_AFTER_POINTER_OFFSET, R0);
133 
134         /**
135          * :pointer_offset_stored
136          * LDHX R0, [R1]                    // R0 = 2-byte pointer value
137          * AND R0, 0x3ff                    // R0 = pointer destination offset (from DNS header)
138          * LDM R1, 1                        // R1 = offset in packet of DNS header
139          * ADD R0, R1                       // R0 = pointer destination offset
140          * LDM R1, 2                        // R1 = current parsing offset
141          * JEQ R0, R1, DROP                 // Drop if pointer points here...
142          * JGT R0, R1, DROP                 // ... or after here (must point backwards)
143          * STM R0, 2                        // Set next parsing offset to pointer destination
144          */
145         gen.defineLabel(labelPointerOffsetStored);
146         gen.addLoad16Indexed(R0, 0);
147         gen.addAnd(0x3ff);
148         gen.addLoadFromMemory(R1, SLOT_DNS_HEADER_OFFSET);
149         gen.addAddR1ToR0();
150         gen.addLoadFromMemory(R1, SLOT_CURRENT_PARSE_OFFSET);
151         gen.addJumpIfR0EqualsR1(ApfV4Generator.DROP_LABEL);
152         gen.addJumpIfR0GreaterThanR1(ApfV4Generator.DROP_LABEL);
153         gen.addStoreToMemory(SLOT_CURRENT_PARSE_OFFSET, R0);
154 
155         /** // Pointer chased. Parse starting from the pointer destination (which may also be a
156          * pointer).
157          * JMP :parse_dns_label
158          */
159         gen.addJump(LABEL_PARSE_DNS_LABEL);
160 
161         /**
162          * :parse_real_label
163          * // This is where the real (non-pointer) label starts.
164          * // Load label length into R1, and return to caller.
165          * // m[SLOT_CURRENT_PARSE_OFFSET] already contains label offset.
166          * LDHX R1, [R1]                    // R1 = label length
167          */
168         gen.defineLabel(labelParseDnsLabelReal);
169         gen.addLoad8Indexed(R1, 0);
170 
171         /** // Return
172          * LDM R0, 10
173          * JMP :jump_table
174          */
175         gen.addLoadFromMemory(R0, SLOT_RETURN_VALUE_INDEX);
176         gen.addJump(jumpTable.getStartLabel());
177     }
178 
179     /**
180      * APF function: find_next_dns_question
181      *
182      * Finds the next question in the question section, or drops the packet if there is none.
183      *
184      * Inputs:
185      * - m[SLOT_CURRENT_PARSE_OFFSET]: current parsing offset
186      * - m[SLOT_AFTER_POINTER_OFFSET]: offset after first pointer in name, or 0 if not chasing a
187      *           pointer
188      * - m[SLOT_NEGATIVE_QDCOUNT_REMAINING]: qdcount remaining, as a negative number. This is
189      *           because adding 1 is much easier than subtracting 1 (which can't be done just by
190      *           adding -1, because that just adds 254)
191      * - m[SLOT_RETURN_VALUE_INDEX]: index into return jump table
192      *
193      * Outputs:
194      * None
195      */
genFindNextDnsQuestion(ApfV4Generator gen, JumpTable jumpTable)196     private static void genFindNextDnsQuestion(ApfV4Generator gen, JumpTable jumpTable)
197             throws Exception {
198         final String labelFindNextDnsQuestionFollow = "find_next_dns_question_follow";
199         final String labelFindNextDnsQuestionLabel = "find_next_dns_question_label";
200         final String labelFindNextDnsQuestionLoop = "find_next_dns_question_loop";
201         final String labelFindNextDnsQuestionNoPointer = "find_next_dns_question_no_pointer";
202         final String labelFindNextDnsQuestionReturn = "find_next_dns_question_return";
203 
204         // Function entry point.
205         gen.defineLabel(LABEL_FIND_NEXT_DNS_QUESTION);
206 
207         // Are we chasing a pointer?
208         gen.addLoadFromMemory(R0, SLOT_AFTER_POINTER_OFFSET);
209         gen.addJumpIfR0Equals(0, labelFindNextDnsQuestionFollow);
210 
211         // If so, offset after the pointer and question is stored in m[SLOT_AFTER_POINTER_OFFSET].
212         // Move parsing offset there, clear m[SLOT_AFTER_POINTER_OFFSET], and return.
213         gen.addStoreToMemory(SLOT_CURRENT_PARSE_OFFSET, R0);
214         gen.addLoadImmediate(R0, 0);
215         gen.addStoreToMemory(SLOT_AFTER_POINTER_OFFSET, R0);
216         gen.addJump(labelFindNextDnsQuestionReturn);
217 
218         // We weren't chasing a pointer. Loop, following the label chain, until we reach a
219         // zero-length label or a pointer. At the beginning of the loop, the current parsing offset
220         // is m[SLOT_CURRENT_PARSE_OFFSET]. Move it to R1 and keep it in R1 throughout the loop.
221         gen.defineLabel(labelFindNextDnsQuestionFollow);
222         gen.addLoadFromMemory(R1, SLOT_CURRENT_PARSE_OFFSET);
223 
224         // Load label length.
225         gen.defineLabel(labelFindNextDnsQuestionLoop);
226         gen.addLoad8Indexed(R0, 0);
227         // Is it a pointer?
228         gen.addAnd(0xc0);
229         gen.addJumpIfR0Equals(0, labelFindNextDnsQuestionNoPointer);
230         // It's a pointer. Skip the pointer and question, and return.
231         gen.addLoadImmediate(R0, POINTER_AND_QUESTION_HEADER_SIZE);
232         gen.addAddR1ToR0();
233         gen.addStoreToMemory(SLOT_CURRENT_PARSE_OFFSET, R0);
234         gen.addJump(labelFindNextDnsQuestionReturn);
235 
236         // R1 still contains parsing offset.
237         gen.defineLabel(labelFindNextDnsQuestionNoPointer);
238         gen.addLoad8Indexed(R0, 0);
239 
240         // Zero-length label? We're done.
241         // Skip the label (1 byte) and query (2 bytes qtype, 2 bytes qclass) and return.
242         gen.addJumpIfR0NotEquals(0, labelFindNextDnsQuestionLabel);
243         gen.addLoadImmediate(R0, LABEL_AND_QUESTION_HEADER_SIZE);
244         gen.addAddR1ToR0();
245         gen.addStoreToMemory(SLOT_CURRENT_PARSE_OFFSET, R0);
246         gen.addJump(labelFindNextDnsQuestionReturn);
247 
248         // Non-zero length label. Consume it and continue.
249         gen.defineLabel(labelFindNextDnsQuestionLabel);
250         gen.addAdd(1);
251         gen.addAddR1ToR0();
252         gen.addMove(R1);
253         gen.addJump(labelFindNextDnsQuestionLoop);
254 
255         gen.defineLabel(labelFindNextDnsQuestionReturn);
256 
257         // Is this the last question? If so, drop.
258         gen.addLoadFromMemory(R0, SLOT_NEGATIVE_QDCOUNT_REMAINING);
259         gen.addAdd(1);
260         gen.addStoreToMemory(SLOT_NEGATIVE_QDCOUNT_REMAINING, R0);
261         gen.addJumpIfR0Equals(0, ApfV4Generator.DROP_LABEL);
262 
263         // If not, return.
264         gen.addJump(jumpTable.getStartLabel());
265     }
266 
267     /** @return jump label that points to the start of a DNS label's parsing code. */
getStartMatchLabel(int labelIndex)268     private static String getStartMatchLabel(int labelIndex) {
269         return "dns_parse_" + labelIndex;
270     }
271 
272     /** @return jump label used while parsing the specified DNS label. */
getPostMatchJumpTargetForLabel(int labelIndex)273     private static String getPostMatchJumpTargetForLabel(int labelIndex) {
274         return "dns_parsed_" + labelIndex;
275     }
276 
277     /** @return jump label used when the match for the specified DNS label fails. */
getNoMatchLabel(int labelIndex)278     private static String getNoMatchLabel(int labelIndex) {
279         return "dns_nomatch_" + labelIndex;
280     }
281 
addMatchLabel(@onNull ApfV4Generator gen, @NonNull JumpTable jumpTable, int labelIndex, @NonNull String label, @NonNull String nextLabel)282     private static void addMatchLabel(@NonNull ApfV4Generator gen, @NonNull JumpTable jumpTable,
283             int labelIndex, @NonNull String label, @NonNull String nextLabel) throws Exception {
284         final String parsedLabel = getPostMatchJumpTargetForLabel(labelIndex);
285         final String noMatchLabel = getNoMatchLabel(labelIndex);
286         gen.defineLabel(getStartMatchLabel(labelIndex));
287 
288         // Store return address.
289         gen.addLoadImmediate(R0, jumpTable.getIndex(parsedLabel));
290         gen.addStoreToMemory(SLOT_RETURN_VALUE_INDEX, R0);
291 
292         // Call the parse_label function.
293         gen.addJump(LABEL_PARSE_DNS_LABEL);
294 
295         gen.defineLabel(parsedLabel);
296 
297         // If label length is 0, this is the end of the name and the match failed.
298         gen.addSwap(); // Move label length from R1 to R0
299         gen.addJumpIfR0Equals(0, noMatchLabel);
300 
301         // Label parsed, check it matches what we're looking for.
302         gen.addJumpIfR0NotEquals(label.length(), noMatchLabel);
303         gen.addLoadFromMemory(R0, SLOT_CURRENT_PARSE_OFFSET);
304         gen.addAdd(1);
305         gen.addJumpIfBytesAtR0NotEqual(label.getBytes(), noMatchLabel);
306 
307         // Prep offset of next label.
308         gen.addAdd(label.length());
309         gen.addStoreToMemory(SLOT_CURRENT_PARSE_OFFSET, R0);
310 
311         // Match, go to next label.
312         gen.addJump(nextLabel);
313 
314         // Match failed. Go to next name, and restart from the first match.
315         gen.defineLabel(noMatchLabel);
316         gen.addLoadImmediate(R1, jumpTable.getIndex(LABEL_START_MATCH));
317         gen.addStoreToMemory(SLOT_RETURN_VALUE_INDEX, R1);
318         gen.addJump(LABEL_FIND_NEXT_DNS_QUESTION);
319     }
320 
321     /**
322      * Generates a filter that accepts DNS packet that ask for the specified name.
323      *
324      * The filter supports compressed DNS names and scanning through multiple questions in the same
325      * packet, e.g., as used by MDNS. However, it currently only supports one DNS name.
326      *
327      * Limitations:
328      * <ul>
329      * <li>Filter size is just under 300 bytes for a typical question.
330      * <li>Because the bytecode extensively uses backwards jumps, it can hit the APF interpreter
331      *   instruction limit. This limit causes the APF interpreter to accept the packet once it has
332      *   executed a number of instructions equal to the program length in bytes.
333      *   A program that consists *only* of this filter will be able to execute just under 300
334      *   instructions, and will be able to correctly drop packets with two questions but not three
335      *   questions. In a real APF setup, there will be other code (e.g., RA filtering) which counts
336      *   against the limit, so the filter should be able to parse packets with more questions.
337      * <li>Matches are case-sensitive. This is due to the use of JNEBS to match DNS labels and is
338      *   likely impossible to overcome without interpreter changes.
339      * </ul>
340      *
341      * TODO:
342      * <ul>
343      * <li>Add unit tests for the parse_dns_label and find_next_dns_question functions.
344      * <li>Support accepting more than one name.
345      * <li>For devices where power saving is a priority (e.g., flat panel TVs), add support for
346      *   dropping packets with more than X queries, to ensure the filter will drop the packet rather
347      *   than hit the instruction limit.
348      * </ul>
349      */
generateFilter(ApfV4Generator gen, String[] labels)350     public static void generateFilter(ApfV4Generator gen, String[] labels) throws Exception {
351         final int etherPlusUdpLen = ETHER_HEADER_LEN + UDP_HEADER_LEN;
352 
353         final String labelJumpTable = "jump_table";
354 
355         // Initialize parsing
356         /**
357          * - R1: length of IP header.
358          * - m[SLOT_DNS_HEADER_OFFSET]: offset of DNS header
359          * - m[SLOT_CURRENT_PARSE_OFFSET]: current parsing offset (start of question section)
360          * - m[SLOT_AFTER_POINTER_OFFSET]: offset after first pointer in name, must be 0 when
361          *                                 starting a new name
362          * - m[SLOT_NEGATIVE_QDCOUNT_REMAINING]: negative qdcount
363          */
364         // Move IP header length to R0 and use it to find the DNS header offset.
365         // TODO: this uses R1 for consistency with ApfFilter#generateMdnsFilterLocked. Evaluate
366         // using R0 instead.
367         gen.addMove(R0);
368         gen.addAdd(etherPlusUdpLen);
369         gen.addStoreToMemory(SLOT_DNS_HEADER_OFFSET, R0);
370 
371         gen.addAdd(DNS_QDCOUNT_OFFSET);
372         gen.addMove(R1);
373         gen.addLoad16Indexed(R1, 0);
374         gen.addNeg(R1);
375         gen.addStoreToMemory(SLOT_NEGATIVE_QDCOUNT_REMAINING, R1);
376 
377         gen.addAdd(DNS_HEADER_LEN - DNS_QDCOUNT_OFFSET);
378         gen.addStoreToMemory(SLOT_CURRENT_PARSE_OFFSET, R0);
379 
380         gen.addLoadImmediate(R0, 0);
381         gen.addStoreToMemory(SLOT_AFTER_POINTER_OFFSET, R0);
382 
383         gen.addJump(LABEL_START_MATCH);
384 
385         // Create JumpTable but
386         final JumpTable table = new JumpTable(labelJumpTable, SLOT_RETURN_VALUE_INDEX);
387 
388         // Generate bytecode for parse_label function.
389         genParseDnsLabel(gen, table);
390         genFindNextDnsQuestion(gen, table);
391 
392         // Populate jump table. Should be before the code that calls to it (i.e., the addMatchLabel
393         // calls below) because otherwise all the jumps are backwards, and backwards jumps are more
394         // expensive (5 bytes of bytecode)
395         for (int i = 0; i < labels.length; i++) {
396             table.addLabel(getPostMatchJumpTargetForLabel(i));
397         }
398         table.addLabel(LABEL_START_MATCH);
399         table.generate(gen);
400 
401         // Add match statements for name.
402         gen.defineLabel(LABEL_START_MATCH);
403         for (int i = 0; i < labels.length; i++) {
404             final String nextLabel = (i == labels.length - 1)
405                     ? ApfV4Generator.PASS_LABEL
406                     : getStartMatchLabel(i + 1);
407             addMatchLabel(gen, table, i, labels[i], nextLabel);
408         }
409         gen.addJump(ApfV4Generator.DROP_LABEL);
410     }
411 
DnsUtils()412     private DnsUtils() {
413     }
414 }
415