1 /*
2  * Copyright 2024, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "apf_interpreter.h"
18 
19 #include <string.h>  /* For memcmp, memcpy, memset */
20 
/* FALLTHROUGH marks an intentional switch-case fall-through for compilers that
 * understand __attribute__((fallthrough)); elsewhere it expands to nothing.
 * defined() is used so that a compiler which does not define __GNUC__ or
 * __clang__ does not evaluate an undefined identifier here (-Wundef). */
#if (defined(__GNUC__) && __GNUC__ >= 7) || defined(__clang__)
#define FALLTHROUGH __attribute__((fallthrough))
#else
#define FALLTHROUGH
#endif

/* Minimal boolean type: False == 0, True == 1. */
typedef enum { False, True } Boolean;
28 
29 /* Begin include of apf_defs.h */
30 typedef int8_t s8;
31 typedef int16_t s16;
32 typedef int32_t s32;
33 
34 typedef uint8_t u8;
35 typedef uint16_t u16;
36 typedef uint32_t u32;
37 
38 typedef enum {
39   error_program = -2,
40   error_packet = -1,
41   nomatch = False,
42   match = True
43 } match_result_type;
44 
45 #define ETH_P_IP	0x0800
46 #define ETH_P_IPV6	0x86DD
47 
48 #define ETH_HLEN	14
49 #define IPV4_HLEN	20
50 #define IPV6_HLEN	40
51 #define TCP_HLEN	20
52 #define UDP_HLEN	8
53 
54 #define FUNC(x) x; x
55 /* End include of apf_defs.h */
56 /* Begin include of apf.h */
57 /*
58  * Copyright 2024, The Android Open Source Project
59  *
60  * Licensed under the Apache License, Version 2.0 (the "License");
61  * you may not use this file except in compliance with the License.
62  * You may obtain a copy of the License at
63  *
64  * http://www.apache.org/licenses/LICENSE-2.0
65  *
66  * Unless required by applicable law or agreed to in writing, software
67  * distributed under the License is distributed on an "AS IS" BASIS,
68  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69  * See the License for the specific language governing permissions and
70  * limitations under the License.
71  */
72 
73 #ifndef ANDROID_APF_APF_H
74 #define ANDROID_APF_APF_H
75 
76 /* A brief overview of APF:
77  *
78  * APF machine is composed of:
79  *  1. A read-only program consisting of bytecodes as described below.
80  *  2. Two 32-bit registers, called R0 and R1.
81  *  3. Sixteen 32-bit temporary memory slots (cleared between packets).
82  *  4. A read-only packet.
83  *  5. An optional read-write transmit buffer.
84  * The program is executed by the interpreter below and parses the packet
85  * to determine if the application processor (AP) should be woken up to
86  * handle the packet or if it can be dropped.  The program may also choose
87  * to allocate/transmit/deallocate the transmit buffer.
88  *
89  * APF bytecode description:
90  *
91  * The APF interpreter uses big-endian byte order for loads from the packet
92  * and for storing immediates in instructions.
93  *
94  * Each instruction starts with a byte composed of:
95  *  Top 5 bits form "opcode" field, see *_OPCODE defines below.
96  *  Next 2 bits form "size field", which indicates the length of an immediate
97  *  value which follows the first byte.  Values in this field:
98  *                 0 => immediate value is 0 and no bytes follow.
99  *                 1 => immediate value is 1 byte big.
100  *                 2 => immediate value is 2 bytes big.
101  *                 3 => immediate value is 4 bytes big.
102  *  Bottom bit forms "register" field, which (usually) indicates which register
103  *  this instruction operates on.
104  *
105  *  There are four main categories of instructions:
106  *  Load instructions
107  *    These instructions load byte(s) of the packet into a register.
108  *    They load either 1, 2 or 4 bytes, as determined by the "opcode" field.
109  *    They load into the register specified by the "register" field.
110  *    The immediate value that follows the first byte of the instruction is
111  *    the byte offset from the beginning of the packet to load from.
112  *    There are "indexing" loads which add the value in R1 to the byte offset
113  *    to load from. The "opcode" field determines which loads are "indexing".
114  *  Arithmetic instructions
115  *    These instructions perform simple operations, like addition, on register
116  *    values. The result of these instructions is always written into R0. One
117  *    argument of the arithmetic operation is R0's value. The other argument
118  *    of the arithmetic operation is determined by the "register" field:
119  *            If the "register" field is 0 then the immediate value following
120  *            the first byte of the instruction is used as the other argument
121  *            to the arithmetic operation.
122  *            If the "register" field is 1 then R1's value is used as the other
123  *            argument to the arithmetic operation.
124  *  Conditional jump instructions
125  *    These instructions compare register R0's value with another value, and if
126  *    the comparison succeeds, jump (i.e. adjust the program counter). The
127  *    immediate value that follows the first byte of the instruction
128  *    represents the jump target offset, i.e. the value added to the program
129  *    counter if the comparison succeeds. The other value compared is
130  *    determined by the "register" field:
131  *            If the "register" field is 0 then another immediate value
132  *            follows the jump target offset. This immediate value is of the
133  *            same size as the jump target offset, and represents the value
134  *            to compare against.
135  *            If the "register" field is 1 then register R1's value is
136  *            compared against.
137  *    The type of comparison (e.g. equal to, greater than etc) is determined
138  *    by the "opcode" field. The comparison interprets both values being
139  *    compared as unsigned values.
140  *  Miscellaneous instructions
141  *    Instructions for:
142  *      - allocating/transmitting/deallocating transmit buffer
143  *      - building the transmit packet (copying bytes into it)
 *      - reading/writing the data section
145  *
146  *  Miscellaneous details:
147  *
148  *  Pre-filled temporary memory slot values
 *    When the APF program begins execution, the interpreter initializes all
 *    sixteen memory slots; several are pre-filled with values that may be
 *    useful for programs:
152  *      #0 to #7 are zero initialized.
153  *      Slot #8  is initialized with apf version (on APF >4).
 *      Slot #9  is the filter age (as in slot #15) at greater resolution, in 1/16384ths of a second.
155  *      Slot #10 starts at zero, implicitly used as tx buffer output pointer.
156  *      Slot #11 contains the size (in bytes) of the APF program.
157  *      Slot #12 contains the total size of the APF program + data.
158  *      Slot #13 is filled with the IPv4 header length. This value is calculated
159  *               by loading the first byte of the IPv4 header and taking the
160  *               bottom 4 bits and multiplying their value by 4. This value is
161  *               set to zero if the first 4 bits after the link layer header are
162  *               not 4, indicating not IPv4.
163  *      Slot #14 is filled with size of the packet in bytes, including the
164  *               ethernet link-layer header.
165  *      Slot #15 is filled with the filter age in seconds. This is the number of
166  *               seconds since the host installed the program. This may
167  *               be used by filters that should have a particular lifetime. For
168  *               example, it can be used to rate-limit particular packets to one
169  *               every N seconds.
170  *  Special jump targets:
171  *    When an APF program executes a jump to the byte immediately after the last
 *      byte of the program (i.e., one byte past the end of the program), this
173  *      signals the program has completed and determined the packet should be
174  *      passed to the AP.
175  *    When an APF program executes a jump two bytes past the end of the program,
176  *      this signals the program has completed and determined the packet should
177  *      be dropped.
178  *  Jump if byte sequence doesn't match:
179  *    This is a special instruction to facilitate matching long sequences of
180  *    bytes in the packet. Initially it is encoded like a conditional jump
181  *    instruction with two exceptions:
182  *      The first byte of the instruction is always followed by two immediate
183  *        fields: The first immediate field is the jump target offset like other
184  *        conditional jump instructions. The second immediate field specifies the
185  *        number of bytes to compare.
186  *      These two immediate fields are followed by a sequence of bytes. These
187  *        bytes are compared with the bytes in the packet starting from the
188  *        position specified by the value of the register specified by the
189  *        "register" field of the instruction.
190  */
191 
/* Number of temporary memory slots, see ldm/stm instructions. */
#define MEMORY_ITEMS 16
/* Upon program execution, some temporary memory slots are prefilled: */

/* Temporary memory, accessible two ways over the same storage:
 *   named - symbolic access; pad[] covers slots 0..7, the following fields are slots 8..15.
 *   slot  - plain array indexed by slot number 0..MEMORY_ITEMS-1.
 * The member order of 'named' therefore fixes the slot numbering and must not change:
 * 8 pad words + 8 named words == MEMORY_ITEMS words.
 */
typedef union {
  struct {
    u32 pad[8];               /* 0..7: zero initialized */
    u32 apf_version;          /* 8:  Initialized with apf_version() */
    u32 filter_age_16384ths;  /* 9:  Age since filter installed in 1/16384 seconds. */
    u32 tx_buf_offset;        /* 10: Offset in tx_buf where next byte will be written */
    u32 program_size;         /* 11: Size of program (in bytes) */
    u32 ram_len;              /* 12: Total size of program + data, ie. ram_len */
    u32 ipv4_header_size;     /* 13: 4*([APF_FRAME_HEADER_SIZE]&15) */
    u32 packet_size;          /* 14: Size of packet in bytes. */
    u32 filter_age;           /* 15: Age since filter installed in seconds. */
  } named;
  u32 slot[MEMORY_ITEMS];
} memory_type;
210 
/* ---------------------------------------------------------------------------------------------- */

/* Standard opcodes: the value held in the top 5 bits of an instruction's first byte. */

/* pass / drop: unconditionally pass (R=0) or drop (R=1) the packet, optionally bumping a counter.
 * An optional non-zero unsigned immediate N selects the counter: a big-endian u32 located
 * 4*N bytes before the end of the data region, always incremented by exactly 1.
 * More or less equivalent to: lddw R0, -4*N; add R0, 1; stdw R0, -4*N; {pass,drop}
 * e.g. "pass", "pass 1", "drop", "drop 1"
 */
#define PASSDROP_OPCODE 0

/* Load 1, 2 or 4 packet bytes from an immediate offset, e.g. "ldb R0, [5]", "ldw R0, [5]". */
#define LDB_OPCODE 1
#define LDH_OPCODE 2
#define LDW_OPCODE 3

/* Indexing loads: 1, 2 or 4 packet bytes from an immediate offset plus R1,
 * e.g. "ldbx R0, [5+R1]", "ldhx R0, [5+R1]", "ldwx R0, [5+R1]". */
#define LDBX_OPCODE 4
#define LDHX_OPCODE 5
#define LDWX_OPCODE 6

/* Arithmetic: R0 = R0 <op> (the immediate when R=0, R1 when R=1). */
#define ADD_OPCODE 7    /* e.g. "add R0,5" */
#define MUL_OPCODE 8    /* e.g. "mul R0,5" */
#define DIV_OPCODE 9    /* e.g. "div R0,5" */
#define AND_OPCODE 10   /* e.g. "and R0,5" */
#define OR_OPCODE 11    /* e.g. "or R0,5" */
#define SH_OPCODE 12    /* left shift, e.g. "sh R0, 5"; a negative amount ("sh R0, -5") shifts right */

/* Load a signed immediate, e.g. "li R0,5". */
#define LI_OPCODE 13

/* Unconditional jump, e.g. "jmp label". */
#define JMP_OPCODE 14

/* Conditional jumps: compare R0 with an immediate (R=0) or with R1 (R=1) as unsigned
 * values and branch when the condition holds, e.g. "jeq R0,5,label". */
#define JEQ_OPCODE 15   /* equal,              e.g. "jeq R0,5,label" */
#define JNE_OPCODE 16   /* not equal,          e.g. "jne R0,5,label" */
#define JGT_OPCODE 17   /* greater than,       e.g. "jgt R0,5,label" */
#define JLT_OPCODE 18   /* less than,          e.g. "jlt R0,5,label" */
#define JSET_OPCODE 19  /* any of the bits set, e.g. "jset R0,5,label" */

/* Byte-sequence compare and branch: R=0 jumps if NOT equal ("jbsne R0,2,label,0x1122"),
 * R=1 jumps if equal ('jbseq').
 * NOTE: Only APFv6+ implements the R=1 'jbseq' version and multi match.
 * imm1 is the jump target, imm2 is (cnt - 1) * 2048 + compare_len, and it is
 * followed by cnt * compare_len bytes to compare against.
 * Warning: do not specify the same byte sequence multiple times.
 */
#define JBSMATCH_OPCODE 20

/* Extended instruction: the immediate holds one of the *_EXT_OPCODE values. */
#define EXT_OPCODE 21

/* Load / store 4 bytes from / to a data address (register + signed immediate),
 * e.g. "lddw R0, [5+R1]", "stdw R0, [5+R1]".
 * In APFv6+ *mode* LDDW/STDW load/store the counter specified in the immediate. */
#define LDDW_OPCODE 22
#define STDW_OPCODE 23

/* write: write a 1, 2 or 4 byte immediate to the output buffer and advance the output
 * buffer pointer.  The immediate length field gives the size of the write; R must be 0
 * and imm_len != 0.  e.g. "write 5"
 */
#define WRITE_OPCODE 24

/* pktcopy / datacopy: copy bytes from the input packet (R=0) or from the APF program/data
 * region (R=1) to the output buffer and advance the output buffer pointer.
 * The source offset is in imm1, the copy length in u8 imm2.
 * e.g. "pktcopy 0, 16" or "datacopy 0, 16"
 */
#define PKTDATACOPY_OPCODE 25

/* JSET with the reverse condition: jump if none of the bits are set. */
#define JNSET_OPCODE 26
270 
/* ---------------------------------------------------------------------------------------------- */

/* Extended opcodes.
 * Every extended instruction is encoded with EXT_OPCODE; its immediate field selects
 * which of the operations below is performed.
 */

/* ldm / stm: load from or store to a temporary memory slot, e.g. "ldm R0,5", "stm R0,5".
 * Values 0-15 (LDM_EXT_OPCODE + n) load slot n; values 16-31 (STM_EXT_OPCODE + n) store slot n. */
#define LDM_EXT_OPCODE 0
#define STM_EXT_OPCODE 16

/* Register operations. */
#define NOT_EXT_OPCODE 32   /* not,    e.g. "not R0" */
#define NEG_EXT_OPCODE 33   /* negate, e.g. "neg R0" */
#define SWAP_EXT_OPCODE 34  /* swap,   e.g. "swap R0,R1" */
#define MOV_EXT_OPCODE 35   /* move,   e.g. "move R0,R1" */

/* allocate: allocate the writable output buffer.
 *   R=0: the length is taken from register R0
 *   R=1: the length is given in u16 imm2
 * e.g. "allocate R0" or "allocate 123"
 * On failure 'pass 3' is executed automatically.
 */
#define ALLOCATE_EXT_OPCODE 36

/* transmit: transmit and deallocate the output buffer (transmission may be delayed until the
 * program terminates).  The buffer length is the output buffer pointer (0 means discard).
 *   R=1 iff a UDP style L4 checksum is wanted
 *   u8  imm2 - IP header offset from the start of the buffer (255 for non-IP packets)
 *   u8  imm3 - offset from the start of the buffer at which to store the L4 checksum
 *              (255 for no L4 checksum)
 *   u8  imm4 - offset from the start of the buffer at which L4 checksumming begins
 *              (present iff imm3 != 255)
 *   u16 imm5 - partial checksum value to include in the L4 checksum (present iff imm3 != 255)
 * e.g. "transmit"
 */
#define TRANSMIT_EXT_OPCODE 37

/* Write a 1, 2 or 4 byte register value to the output buffer and advance the output buffer
 * pointer, e.g. "ewrite1 r0" or "ewrite2 r1". */
#define EWRITE1_EXT_OPCODE 38
#define EWRITE2_EXT_OPCODE 39
#define EWRITE4_EXT_OPCODE 40

/* epktcopy / edatacopy: copy bytes from the input packet (R=0) or from the APF program/data
 * region (R=1) to the output buffer and advance the output buffer pointer.
 * The source offset is in R0; the copy length is in u8 imm2 (IMM form) or in R1 (R1 form).
 * e.g. "epktcopy r0, 16", "edatacopy r0, 16", "epktcopy r0, r1", "edatacopy r0, r1"
 */
#define EPKTDATACOPYIMM_EXT_OPCODE 41
#define EPKTDATACOPYR1_EXT_OPCODE 42

/* jdnsqeq / jdnsqne: jump if the UDP payload content (starting at R0) does [not] match one
 * of the specified QNAMEs in question records, applying case insensitivity.
 * The SAFE version PASSES corrupt packets, while the other one DROPS them.
 *   R=0/1 meaning 'does not match'/'matches'
 *   R0:           offset to the UDP payload content
 *   imm1:         extended opcode
 *   imm2:         jump label offset
 *   imm3 (u8):    question type (PTR/SRV/TXT/A/AAAA)
 *   imm4 (bytes): null terminated list of null terminated LV-encoded QNAMEs
 * e.g. "jdnsqeq R0,label,0xc,\002aa\005local\0\0", "jdnsqne R0,label,0xc,\002aa\005local\0\0"
 */
#define JDNSQMATCH_EXT_OPCODE 43
#define JDNSQMATCHSAFE_EXT_OPCODE 45

/* jdnsaeq / jdnsane: jump if the UDP payload content (starting at R0) does [not] match one
 * of the specified NAMEs in answer/authority/additional records, applying case
 * insensitivity.  There is no question type operand.
 * The SAFE version PASSES corrupt packets, while the other one DROPS them.
 *   R=0/1 meaning 'does not match'/'matches'
 *   R0:           offset to the UDP payload content
 *   imm1:         extended opcode
 *   imm2:         jump label offset
 *   imm3 (bytes): null terminated list of null terminated LV-encoded NAMEs
 * e.g. "jdnsaeq R0,label,\002aa\005local\0\0", "jdnsane R0,label,\002aa\005local\0\0"
 */
#define JDNSAMATCH_EXT_OPCODE 44
#define JDNSAMATCHSAFE_EXT_OPCODE 46

/* Jump if a register is [not] one of a list of values.
 *   R bit:     the register (R0/R1) to test
 *   imm1:      extended opcode
 *   imm2:      jump label offset
 *   imm3 (u8): top 5 bits    - number 'n' of following u8/be16/be32 values - 2
 *              middle 2 bits - 1..4 length of immediates - 1
 *              bottom 1 bit  - =0 jmp if in set, =1 if not in set
 *   imm4 (n * 1/2/3/4 bytes): the *UNIQUE* values to compare against
 */
#define JONEOF_EXT_OPCODE 47

/* Specify the length of the exception buffer, which is populated on abnormal program
 * termination.
 *   imm1:       extended opcode
 *   imm2 (u16): length of the exception buffer (located *immediately* after the program itself)
 */
#define EXCEPTIONBUFFER_EXT_OPCODE 48

/* Internal extended opcode used to implement PKTDATACOPY_OPCODE. */
#define PKTDATACOPYIMM_EXT_OPCODE 65536

/* Split an instruction's first byte into its fields:
 * opcode = top 5 bits, immediate length code = next 2 bits, register = bottom bit. */
#define EXTRACT_OPCODE(i) (((i) >> 3) & 31)
#define EXTRACT_REGISTER(i) ((i) & 1)
#define EXTRACT_IMM_LENGTH(i) (((i) >> 1) & 3)
370 
371 #endif  /* ANDROID_APF_APF_H */
372 /* End include of apf.h */
373 /* Begin include of apf_utils.h */
read_be16(const u8 * buf)374 static u32 read_be16(const u8* buf) {
375     return buf[0] * 256u + buf[1];
376 }
377 
/* Store v into buf[0..1] in big-endian (most significant byte first) order. */
static void store_be16(u8* const buf, const u16 v) {
    const u8 msb = (u8)(v >> 8);
    const u8 lsb = (u8)(v & 0xFFu);
    buf[0] = msb;
    buf[1] = lsb;
}
382 
/* ASCII-only upper-casing of one byte: 'a'..'z' become 'A'..'Z', every other byte is returned as is. */
static u8 uppercase(u8 c) {
    if (c < 'a' || c > 'z') return c;
    return (u8)(c - ('a' - 'A'));
}
386 /* End include of apf_utils.h */
387 /* Begin include of apf_dns.h */
388 /**
389  * Compares a (Q)NAME starting at udp[*ofs] with the target name.
390  *
391  * @param needle - non-NULL - pointer to DNS encoded target name to match against.
392  *   example: [11]_googlecast[4]_tcp[5]local[0]  (where [11] is a byte with value 11)
393  * @param needle_bound - non-NULL - points at first invalid byte past needle.
394  * @param udp - non-NULL - pointer to the start of the UDP payload (DNS header).
395  * @param udp_len - length of the UDP payload.
396  * @param ofs - non-NULL - pointer to the offset of the beginning of the (Q)NAME.
397  *   On non-error return will be updated to point to the first unread offset,
398  *   ie. the next position after the (Q)NAME.
399  *
400  * @return 1 if matched, 0 if not matched, -1 if error in packet, -2 if error in program.
401  */
FUNC(match_result_type apf_internal_match_single_name (const u8 * needle,const u8 * const needle_bound,const u8 * const udp,const u32 udp_len,u32 * const ofs))402 FUNC(match_result_type apf_internal_match_single_name(const u8* needle,
403                                     const u8* const needle_bound,
404                                     const u8* const udp,
405                                     const u32 udp_len,
406                                     u32* const ofs)) {
407     u32 first_unread_offset = *ofs;
408     Boolean is_qname_match = True;
409     int lvl;
410 
411     /* DNS names are <= 255 characters including terminating 0, since >= 1 char + '.' per level => max. 127 levels */
412     for (lvl = 1; lvl <= 127; ++lvl) {
413         u8 v;
414         if (*ofs >= udp_len) return error_packet;
415         v = udp[(*ofs)++];
416         if (v >= 0xC0) { /* RFC 1035 4.1.4 - handle message compression */
417             u8 w;
418             u32 new_ofs;
419             if (*ofs >= udp_len) return error_packet;
420             w = udp[(*ofs)++];
421             if (*ofs > first_unread_offset) first_unread_offset = *ofs;
422             new_ofs = (v - 0xC0) * 256u + w;
423             if (new_ofs >= *ofs) return error_packet;  /* RFC 1035 4.1.4 allows only backward pointers */
424             *ofs = new_ofs;
425         } else if (v > 63) {
426             return error_packet;  /* RFC 1035 2.3.4 - label size is 1..63. */
427         } else if (v) {
428             u8 label_size = v;
429             if (*ofs + label_size > udp_len) return error_packet;
430             if (needle >= needle_bound) return error_program;
431             if (is_qname_match) {
432                 u8 len = *needle++;
433                 if (len == label_size) {
434                     if (needle + label_size > needle_bound) return error_program;
435                     while (label_size--) {
436                         u8 w = udp[(*ofs)++];
437                         is_qname_match &= (uppercase(w) == *needle++);
438                     }
439                 } else {
440                     if (len != 0xFF) is_qname_match = False;
441                     *ofs += label_size;
442                 }
443             } else {
444                 is_qname_match = False;
445                 *ofs += label_size;
446             }
447         } else { /* reached the end of the name */
448             if (first_unread_offset > *ofs) *ofs = first_unread_offset;
449             return (is_qname_match && *needle == 0) ? match : nomatch;
450         }
451     }
452     return error_packet;  /* too many dns domain name levels */
453 }
454 
455 /**
456  * Check if DNS packet contains any of the target names with the provided
457  * question_type.
458  *
459  * @param needles - non-NULL - pointer to DNS encoded target nameS to match against.
460  *   example: [3]foo[3]com[0][3]bar[3]net[0][0]  -- note ends with an extra NULL byte.
461  * @param needle_bound - non-NULL - points at first invalid byte past needles.
462  * @param udp - non-NULL - pointer to the start of the UDP payload (DNS header).
463  * @param udp_len - length of the UDP payload.
464  * @param question_type - question type to match against or -1 to match answers.
465  *
466  * @return 1 if matched, 0 if not matched, -1 if error in packet, -2 if error in program.
467  */
FUNC(match_result_type apf_internal_match_names (const u8 * needles,const u8 * const needle_bound,const u8 * const udp,const u32 udp_len,const int question_type))468 FUNC(match_result_type apf_internal_match_names(const u8* needles,
469                               const u8* const needle_bound,
470                               const u8* const udp,
471                               const u32 udp_len,
472                               const int question_type)) {
473     u32 num_questions, num_answers;
474     if (udp_len < 12) return error_packet;  /* lack of dns header */
475 
476     /* dns header: be16 tid, flags, num_{questions,answers,authority,additional} */
477     num_questions = read_be16(udp + 4);
478     num_answers = read_be16(udp + 6) + read_be16(udp + 8) + read_be16(udp + 10);
479 
480     /* loop until we hit final needle, which is a null byte */
481     while (True) {
482         u32 i, ofs = 12;  /* dns header is 12 bytes */
483         if (needles >= needle_bound) return error_program;
484         if (!*needles) return nomatch;  /* we've run out of needles without finding a match */
485         /* match questions */
486         for (i = 0; i < num_questions; ++i) {
487             match_result_type m = apf_internal_match_single_name(needles, needle_bound, udp, udp_len, &ofs);
488             int qtype;
489             if (m < nomatch) return m;
490             if (ofs + 2 > udp_len) return error_packet;
491             qtype = (int)read_be16(udp + ofs);
492             ofs += 4; /* skip be16 qtype & qclass */
493             if (question_type == -1) continue;
494             if (m == nomatch) continue;
495             if (qtype == 0xFF /* QTYPE_ANY */ || qtype == question_type) return match;
496         }
497         /* match answers */
498         if (question_type == -1) for (i = 0; i < num_answers; ++i) {
499             match_result_type m = apf_internal_match_single_name(needles, needle_bound, udp, udp_len, &ofs);
500             if (m < nomatch) return m;
501             ofs += 8; /* skip be16 type, class & be32 ttl */
502             if (ofs + 2 > udp_len) return error_packet;
503             ofs += 2 + read_be16(udp + ofs);  /* skip be16 rdata length field, plus length bytes */
504             if (m == match) return match;
505         }
506         /* move needles pointer to the next needle. */
507         do {
508             u8 len = *needles++;
509             if (len == 0xFF) continue;
510             if (len > 63) return error_program;
511             needles += len;
512             if (needles >= needle_bound) return error_program;
513         } while (*needles);
514         needles++;  /* skip the NULL byte at the end of *a* DNS name */
515     }
516 }
517 /* End include of apf_dns.h */
518 /* Begin include of apf_checksum.h */
519 /**
520  * Calculate big endian 16-bit sum of a buffer (max 128kB),
521  * then fold and negate it, producing a 16-bit result in [0..FFFE].
522  */
FUNC(u16 apf_internal_calc_csum (u32 sum,const u8 * const buf,const s32 len))523 FUNC(u16 apf_internal_calc_csum(u32 sum, const u8* const buf, const s32 len)) {
524     u16 csum;
525     s32 i;
526     for (i = 0; i < len; ++i) sum += buf[i] * ((i & 1) ? 1u : 256u);
527 
528     sum = (sum & 0xFFFF) + (sum >> 16);  /* max after this is 1FFFE */
529     csum = sum + (sum >> 16);
530     return ~csum;  /* assuming sum > 0 on input, this is in [0..FFFE] */
531 }
532 
/* UDP style checksum adjustment: a computed checksum of 0 is replaced by 0xFFFF,
 * any other value is kept. */
static u16 fix_udp_csum(u16 csum) {
    if (csum == 0) return 0xFFFF;
    return csum;
}
536 
537 /**
538  * Calculate and store packet checksums and return dscp.
539  *
540  * @param pkt - pointer to the very start of the to-be-transmitted packet,
541  *              ie. the start of the ethernet header (if one is present)
542  *     WARNING: at minimum 266 bytes of buffer pointed to by 'pkt' pointer
543  *              *MUST* be writable.
544  * (IPv4 header checksum is a 2 byte value, 10 bytes after ip_ofs,
545  * which has a maximum value of 254.  Thus 254[ip_ofs] + 10 + 2[u16] = 266)
546  *
547  * @param len - length of the packet (this may be < 266).
548  * @param ip_ofs - offset from beginning of pkt to IPv4 or IPv6 header:
549  *                 IP version detected based on top nibble of this byte,
550  *                 for IPv4 we will calculate and store IP header checksum,
551  *                 but only for the first 20 bytes of the header,
552  *                 prior to calling this the IPv4 header checksum field
553  *                 must be initialized to the partial checksum of the IPv4
554  *                 options (0 if none)
555  *                 255 means there is no IP header (for example ARP)
556  *                 DSCP will be retrieved from this IP header (0 if none).
557  * @param partial_csum - additional value to include in L4 checksum
558  * @param csum_start - offset from beginning of pkt to begin L4 checksum
559  *                     calculation (until end of pkt specified by len)
560  * @param csum_ofs - offset from beginning of pkt to store L4 checksum
561  *                   255 means do not calculate/store L4 checksum
562  * @param udp - True iff we should generate a UDP style L4 checksum (0 -> 0xFFFF)
563  *
564  * @return 6-bit DSCP value [0..63], garbage on parse error.
565  */
FUNC(int apf_internal_csum_and_return_dscp (u8 * const pkt,const s32 len,const u8 ip_ofs,const u16 partial_csum,const u8 csum_start,const u8 csum_ofs,const Boolean udp))566 FUNC(int apf_internal_csum_and_return_dscp(u8* const pkt, const s32 len, const u8 ip_ofs,
567   const u16 partial_csum, const u8 csum_start, const u8 csum_ofs, const Boolean udp)) {
568     if (csum_ofs < 255) {
569         /* note that apf_internal_calc_csum() treats negative lengths as zero */
570         u32 csum = apf_internal_calc_csum(partial_csum, pkt + csum_start, len - csum_start);
571         if (udp) csum = fix_udp_csum(csum);
572         store_be16(pkt + csum_ofs, csum);
573     }
574     if (ip_ofs < 255) {
575         u8 ip = pkt[ip_ofs] >> 4;
576         if (ip == 4) {
577             store_be16(pkt + ip_ofs + 10, apf_internal_calc_csum(0, pkt + ip_ofs, IPV4_HLEN));
578             return pkt[ip_ofs + 1] >> 2;  /* DSCP */
579         } else if (ip == 6) {
580             return (read_be16(pkt + ip_ofs) >> 6) & 0x3F;  /* DSCP */
581         }
582     }
583     return 0;
584 }
585 /* End include of apf_checksum.h */
586 
/* Optional user hook for interpreter debug tracing.
 * When APF_TRACE_HOOK is defined (as the name of a tracing function), that function is
 * declared extern below; otherwise APF_TRACE_HOOK(...) is a statement that does nothing. */
#ifndef APF_TRACE_HOOK
#define APF_TRACE_HOOK(pc, regs, program, program_len, packet, packet_len, memory, memory_len) \
    do {                                                                                       \
        /* tracing disabled */                                                                 \
    } while (0)
#else
extern void APF_TRACE_HOOK(u32 pc, const u32* regs, const u8* program,
                           u32 program_len, const u8 *packet, u32 packet_len,
                           const u32* memory, u32 ram_len);
#endif
597 
/* Return code indicating "packet" should be accepted. */
#define PASS 1
/* Return code indicating "packet" should be accepted (and something unexpected happened). */
#define EXCEPTION 2
/* Return code indicating "packet" should be dropped. */
#define DROP 0
/* Verify an internal condition and accept the packet (return EXCEPTION) if it fails.
 * The do { } while (0) wrapper makes the macro a single statement, so it is safe inside an
 * unbraced if/else (a bare "if" expansion would capture a following "else"). */
#define ASSERT_RETURN(c) do { if (!(c)) return EXCEPTION; } while (0)
/* If "c" is of a signed type, comparing it with its u32 conversion is a signed/unsigned */
/* comparison, which generates a compile warning that gets promoted to an error. */
/* This makes bounds checking simpler because ">= 0" can be avoided. Otherwise adding */
/* superfluous ">= 0" with unsigned expressions generates compile warnings. */
#define ENFORCE_UNSIGNED(c) ((c)==(u32)(c))
610 
apf_version(void)611 u32 apf_version(void) {
612     return 20240510;
613 }
614 
/* Interpreter execution state: registers, memory slots, and the program, packet and
 * output (tx) buffers.  It is passed by pointer to helpers such as
 * apf_internal_do_transmit_buffer() and decode_be16().
 * Fields are ordered for compact packing (see the inline notes); keep that in mind
 * when adding new ones. */
typedef struct {
    /* Note: the following 4 fields take up exactly 8 bytes. */
    u16 except_buf_sz; /* Length of the exception buffer (at program_len offset) */
    u8 ptr_size;       /* sizeof(void*) */
    u8 v6;             /* Set to 1 by first jmpdata (APFv6+) instruction */
    u32 pc;            /* Program counter: offset in 'program' of the next byte to decode. */
    /* All the pointers should be next to each other for better struct packing. */
    /* We are at offset 8, so even 64-bit pointers will not need extra padding. */
    void *caller_ctx;  /* Passed in to interpreter, passed through to alloc/transmit. */
    u8* tx_buf;        /* The output buffer pointer (reset to NULL once transmitted/discarded) */
    u8* program;       /* Pointer to program/data buffer */
    const u8* packet;  /* Pointer to input packet buffer */
    /* Order fields in order of decreasing size */
    u32 tx_buf_len;    /* The length of the output buffer */
    u32 program_len;   /* Length of the program */
    u32 ram_len;       /* Length of the entire apf program/data region */
    u32 packet_len;    /* Length of the input packet buffer */
    u32 R[2];          /* Register values (the two registers R0 and R1). */
    memory_type mem;   /* Memory slot values.  (array of u32s) */
    /* Note: any extra u16s go here, then u8s */
} apf_context;
636 
FUNC(int apf_internal_do_transmit_buffer (apf_context * ctx,u32 pkt_len,u8 dscp))637 FUNC(int apf_internal_do_transmit_buffer(apf_context* ctx, u32 pkt_len, u8 dscp)) {
638     int ret = apf_transmit_buffer(ctx->caller_ctx, ctx->tx_buf, pkt_len, dscp);
639     ctx->tx_buf = NULL;
640     ctx->tx_buf_len = 0;
641     return ret;
642 }
643 
do_discard_buffer(apf_context * ctx)644 static int do_discard_buffer(apf_context* ctx) {
645     return apf_internal_do_transmit_buffer(ctx, 0 /* pkt_len */, 0 /* dscp */);
646 }
647 
/* Fetch the program byte at pc and advance pc by one. Expects an */
/* apf_context* named 'ctx' in scope. Not bounds checked: it relies on ram */
/* extending past program_len (do_apf_run requires program_len + 20 <= ram_len) */
/* so the first few immediates after the last opcode stay inside ram. */
#define DECODE_U8() (ctx->program[ctx->pc++])
649 
decode_be16(apf_context * ctx)650 static u16 decode_be16(apf_context* ctx) {
651     u16 v = DECODE_U8();
652     v <<= 8;
653     v |= DECODE_U8();
654     return v;
655 }
656 
657 /* Decode an immediate, lengths [0..4] all work, does not do range checking. */
658 /* But note that program is at least 20 bytes shorter than ram, so first few */
659 /* immediates can always be safely decoded without exceeding ram buffer. */
decode_imm(apf_context * ctx,u32 length)660 static u32 decode_imm(apf_context* ctx, u32 length) {
661     u32 i, v = 0;
662     for (i = 0; i < length; ++i) v = (v << 8) | DECODE_U8();
663     return v;
664 }
665 
666 /* Warning: 'ofs' should be validated by caller! */
read_packet_u8(apf_context * ctx,u32 ofs)667 static u8 read_packet_u8(apf_context* ctx, u32 ofs) {
668     return ctx->packet[ofs];
669 }
670 
do_apf_run(apf_context * ctx)671 static int do_apf_run(apf_context* ctx) {
672 /* Is offset within ram bounds? */
673 #define IN_RAM_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < ctx->ram_len)
674 /* Is offset within packet bounds? */
675 #define IN_PACKET_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < ctx->packet_len)
676 /* Is access to offset |p| length |size| within data bounds? */
677 #define IN_DATA_BOUNDS(p, size) (ENFORCE_UNSIGNED(p) && \
678                                  ENFORCE_UNSIGNED(size) && \
679                                  (p) + (size) <= ctx->ram_len && \
680                                  (p) + (size) >= (p))  /* catch wraparounds */
681 /* Accept packet if not within ram bounds */
682 #define ASSERT_IN_RAM_BOUNDS(p) ASSERT_RETURN(IN_RAM_BOUNDS(p))
683 /* Accept packet if not within packet bounds */
684 #define ASSERT_IN_PACKET_BOUNDS(p) ASSERT_RETURN(IN_PACKET_BOUNDS(p))
685 /* Accept packet if not within data bounds */
686 #define ASSERT_IN_DATA_BOUNDS(p, size) ASSERT_RETURN(IN_DATA_BOUNDS(p, size))
687 
688     /* Counters start at end of RAM and count *backwards* so this array takes negative integers. */
689     u32 *counter = (u32*)(ctx->program + ctx->ram_len);
690 
691     /* Count of instructions remaining to execute. This is done to ensure an */
692     /* upper bound on execution time. It should never be hit and is only for */
693     /* safety. Initialize to the number of bytes in the program which is an */
694     /* upper bound on the number of instructions in the program. */
695     u32 instructions_remaining = ctx->program_len;
696 
697     /* APFv6 requires at least 5 u32 counters at the end of ram, this makes counter[-5]++ valid */
698     /* This cannot wrap due to previous check, that enforced program_len & ram_len < 2GiB. */
699     if (ctx->program_len + 20 > ctx->ram_len) return EXCEPTION;
700 
701     /* Only populate if packet long enough, and IP version is IPv4. */
702     /* Note: this doesn't actually check the ethertype... */
703     if ((ctx->packet_len >= ETH_HLEN + IPV4_HLEN)
704         && ((read_packet_u8(ctx, ETH_HLEN) & 0xf0) == 0x40)) {
705         ctx->mem.named.ipv4_header_size = (read_packet_u8(ctx, ETH_HLEN) & 15) * 4;
706     }
707 
708 /* Is access to offset |p| length |size| within output buffer bounds? */
709 #define IN_OUTPUT_BOUNDS(p, size) (ENFORCE_UNSIGNED(p) && \
710                                  ENFORCE_UNSIGNED(size) && \
711                                  (p) + (size) <= ctx->tx_buf_len && \
712                                  (p) + (size) >= (p))
713 /* Accept packet if not write within allocated output buffer */
714 #define ASSERT_IN_OUTPUT_BOUNDS(p, size) ASSERT_RETURN(IN_OUTPUT_BOUNDS(p, size))
715 
716     do {
717       APF_TRACE_HOOK(ctx->pc, ctx->R, ctx->program, ctx->program_len,
718                      ctx->packet, ctx->packet_len, ctx->mem.slot, ctx->ram_len);
719       if (ctx->pc == ctx->program_len + 1) return DROP;
720       if (ctx->pc == ctx->program_len) return PASS;
721       if (ctx->pc > ctx->program_len) return EXCEPTION;
722 
723       {  /* half indent to avoid needless line length... */
724 
725         const u8 bytecode = DECODE_U8();
726         const u8 opcode = EXTRACT_OPCODE(bytecode);
727         const u8 reg_num = EXTRACT_REGISTER(bytecode);
728 #define REG (ctx->R[reg_num])
729 #define OTHER_REG (ctx->R[reg_num ^ 1])
730         /* All instructions have immediate fields, so load them now. */
731         const u8 len_field = EXTRACT_IMM_LENGTH(bytecode);
732         const u8 imm_len = ((len_field + 1u) >> 2) + len_field; /* 0,1,2,3 -> 0,1,2,4 */
733         u32 pktcopy_src_offset = 0;  /* used for various pktdatacopy opcodes */
734         u32 imm = 0;
735         s32 signed_imm = 0;
736         u32 arith_imm;
737         s32 arith_signed_imm;
738         if (len_field != 0) {
739             imm = decode_imm(ctx, imm_len); /* 1st imm, at worst bytes 1-4 past opcode/program_len */
740             /* Sign extend imm into signed_imm. */
741             signed_imm = (s32)(imm << ((4 - imm_len) * 8));
742             signed_imm >>= (4 - imm_len) * 8;
743         }
744 
745         /* See comment at ADD_OPCODE for the reason for ARITH_REG/arith_imm/arith_signed_imm. */
746 #define ARITH_REG (ctx->R[reg_num & ctx->v6])
747         arith_imm = (ctx->v6) ? (len_field ? imm : OTHER_REG) : (reg_num ? ctx->R[1] : imm);
748         arith_signed_imm = (ctx->v6) ? (len_field ? signed_imm : (s32)OTHER_REG) : (reg_num ? (s32)ctx->R[1] : signed_imm);
749 
750         switch (opcode) {
751           case PASSDROP_OPCODE: {  /* APFv6+ */
752             if (len_field > 2) return EXCEPTION;  /* max 64K counters (ie. imm < 64K) */
753             if (imm) {
754                 if (4 * imm > ctx->ram_len) return EXCEPTION;
755                 counter[-(s32)imm]++;
756             }
757             return reg_num ? DROP : PASS;
758           }
759           case LDB_OPCODE:
760           case LDH_OPCODE:
761           case LDW_OPCODE:
762           case LDBX_OPCODE:
763           case LDHX_OPCODE:
764           case LDWX_OPCODE: {
765             u32 load_size = 0;
766             u32 offs = imm;
767             /* Note: this can overflow and actually decrease offs. */
768             if (opcode >= LDBX_OPCODE) offs += ctx->R[1];
769             ASSERT_IN_PACKET_BOUNDS(offs);
770             switch (opcode) {
771               case LDB_OPCODE:
772               case LDBX_OPCODE:
773                 load_size = 1;
774                 break;
775               case LDH_OPCODE:
776               case LDHX_OPCODE:
777                 load_size = 2;
778                 break;
779               case LDW_OPCODE:
780               case LDWX_OPCODE:
781                 load_size = 4;
782                 break;
783               /* Immediately enclosing switch statement guarantees */
784               /* opcode cannot be any other value. */
785             }
786             {
787                 const u32 end_offs = offs + (load_size - 1);
788                 u32 val = 0;
789                 /* Catch overflow/wrap-around. */
790                 ASSERT_RETURN(end_offs >= offs);
791                 ASSERT_IN_PACKET_BOUNDS(end_offs);
792                 while (load_size--) val = (val << 8) | read_packet_u8(ctx, offs++);
793                 REG = val;
794             }
795             break;
796           }
797           case JMP_OPCODE:
798             if (reg_num && !ctx->v6) {  /* APFv6+ */
799                 /* First invocation of APFv6 jmpdata instruction */
800                 counter[-1] = 0x12345678;  /* endianness marker */
801                 counter[-2]++;  /* total packets ++ */
802                 ctx->v6 = (u8)True;
803             }
804             /* This can jump backwards. Infinite looping prevented by instructions_remaining. */
805             ctx->pc += imm;
806             break;
807           case JEQ_OPCODE:
808           case JNE_OPCODE:
809           case JGT_OPCODE:
810           case JLT_OPCODE:
811           case JSET_OPCODE:
812           case JNSET_OPCODE: {
813             u32 cmp_imm = 0;
814             /* Load second immediate field. */
815             if (reg_num == 1) {
816                 cmp_imm = ctx->R[1];
817             } else {
818                 cmp_imm = decode_imm(ctx, imm_len); /* 2nd imm, at worst 8 bytes past prog_len */
819             }
820             switch (opcode) {
821               case JEQ_OPCODE:   if (  ctx->R[0] == cmp_imm ) ctx->pc += imm; break;
822               case JNE_OPCODE:   if (  ctx->R[0] != cmp_imm ) ctx->pc += imm; break;
823               case JGT_OPCODE:   if (  ctx->R[0] >  cmp_imm ) ctx->pc += imm; break;
824               case JLT_OPCODE:   if (  ctx->R[0] <  cmp_imm ) ctx->pc += imm; break;
825               case JSET_OPCODE:  if (  ctx->R[0] &  cmp_imm ) ctx->pc += imm; break;
826               case JNSET_OPCODE: if (!(ctx->R[0] &  cmp_imm)) ctx->pc += imm; break;
827             }
828             break;
829           }
830           case JBSMATCH_OPCODE: {
831             /* Load second immediate field. */
832             u32 cmp_imm = decode_imm(ctx, imm_len); /* 2nd imm, at worst 8 bytes past prog_len */
833             u32 cnt = (cmp_imm >> 11) + 1; /* 1+, up to 32 fits in u16 */
834             u32 len = cmp_imm & 2047; /* 0..2047 */
835             u32 bytes = cnt * len;
836             const u32 last_packet_offs = ctx->R[0] + len - 1;
837             Boolean matched = False;
838             /* bytes = cnt * len is size in bytes of data to compare. */
839             /* pc is offset of program bytes to compare. */
840             /* imm is jump target offset. */
841             /* R0 is offset of packet bytes to compare. */
842             if (bytes > 0xFFFF) return EXCEPTION;
843             /* pc < program_len < ram_len < 2GiB, thus pc + bytes cannot wrap */
844             if (!IN_RAM_BOUNDS(ctx->pc + bytes - 1)) return EXCEPTION;
845             ASSERT_IN_PACKET_BOUNDS(ctx->R[0]);
846             /* Note: this will return EXCEPTION (due to wrap) if imm_len (ie. len) is 0 */
847             ASSERT_RETURN(last_packet_offs >= ctx->R[0]);
848             ASSERT_IN_PACKET_BOUNDS(last_packet_offs);
849             while (cnt--) {
850                 matched |= !memcmp(ctx->program + ctx->pc, ctx->packet + ctx->R[0], len);
851                 /* skip past comparison bytes */
852                 ctx->pc += len;
853             }
854             if (matched ^ !reg_num) ctx->pc += imm;
855             break;
856           }
857           /* There is a difference in APFv4 and APFv6 arithmetic behaviour! */
858           /* APFv4:  R[0] op= Rbit ? R[1] : imm;  (and it thus doesn't make sense to have R=1 && len_field>0) */
859           /* APFv6+: REG  op= len_field ? imm : OTHER_REG;  (note: this is *DIFFERENT* with R=1 len_field==0) */
860           /* Furthermore APFv4 uses unsigned imm (except SH), while APFv6 uses signed_imm for ADD/AND/SH. */
861           case ADD_OPCODE: ARITH_REG += (ctx->v6) ? (u32)arith_signed_imm : arith_imm; break;
862           case MUL_OPCODE: ARITH_REG *= arith_imm; break;
863           case AND_OPCODE: ARITH_REG &= (ctx->v6) ? (u32)arith_signed_imm : arith_imm; break;
864           case OR_OPCODE:  ARITH_REG |= arith_imm; break;
865           case DIV_OPCODE: {  /* see above comment! */
866             const u32 div_operand = arith_imm;
867             ASSERT_RETURN(div_operand);
868             ARITH_REG /= div_operand;
869             break;
870           }
871           case SH_OPCODE: {  /* see above comment! */
872             if (arith_signed_imm >= 0)
873                 ARITH_REG <<= arith_signed_imm;
874             else
875                 ARITH_REG >>= -arith_signed_imm;
876             break;
877           }
878           case LI_OPCODE:
879             REG = (u32)signed_imm;
880             break;
881           case PKTDATACOPY_OPCODE:
882             pktcopy_src_offset = imm;
883             imm = PKTDATACOPYIMM_EXT_OPCODE;
884             FALLTHROUGH;
885           case EXT_OPCODE:
886             if (/* imm >= LDM_EXT_OPCODE &&  -- but note imm is u32 and LDM_EXT_OPCODE is 0 */
887                 imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
888                 REG = ctx->mem.slot[imm - LDM_EXT_OPCODE];
889             } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
890                 ctx->mem.slot[imm - STM_EXT_OPCODE] = REG;
891             } else switch (imm) {
892               case NOT_EXT_OPCODE: REG = ~REG;      break;
893               case NEG_EXT_OPCODE: REG = -REG;      break;
894               case MOV_EXT_OPCODE: REG = OTHER_REG; break;
895               case SWAP_EXT_OPCODE: {
896                 u32 tmp = REG;
897                 REG = OTHER_REG;
898                 OTHER_REG = tmp;
899                 break;
900               }
901               case ALLOCATE_EXT_OPCODE:
902                 ASSERT_RETURN(ctx->tx_buf == NULL);
903                 if (reg_num == 0) {
904                     ctx->tx_buf_len = REG;
905                 } else {
906                     ctx->tx_buf_len = decode_be16(ctx); /* 2nd imm, at worst 6 B past prog_len */
907                 }
908                 /* checksumming functions requires minimum 266 byte buffer for correctness */
909                 if (ctx->tx_buf_len < 266) ctx->tx_buf_len = 266;
910                 ctx->tx_buf = apf_allocate_buffer(ctx->caller_ctx, ctx->tx_buf_len);
911                 if (!ctx->tx_buf) {  /* allocate failure */
912                     ctx->tx_buf_len = 0;
913                     counter[-3]++;
914                     return EXCEPTION;
915                 }
916                 memset(ctx->tx_buf, 0, ctx->tx_buf_len);
917                 ctx->mem.named.tx_buf_offset = 0;
918                 break;
919               case TRANSMIT_EXT_OPCODE: {
920                 /* tx_buf_len cannot be large because we'd run out of RAM, */
921                 /* so the above unsigned comparison effectively guarantees casting pkt_len */
922                 /* to a signed value does not result in it going negative. */
923                 u8 ip_ofs = DECODE_U8();              /* 2nd imm, at worst 5 B past prog_len */
924                 u8 csum_ofs = DECODE_U8();            /* 3rd imm, at worst 6 B past prog_len */
925                 u8 csum_start = 0;
926                 u16 partial_csum = 0;
927                 u32 pkt_len = ctx->mem.named.tx_buf_offset;
928                 ASSERT_RETURN(ctx->tx_buf);
929                 /* If pkt_len > allocate_buffer_len, it means sth. wrong */
930                 /* happened and the tx_buf should be deallocated. */
931                 if (pkt_len > ctx->tx_buf_len) {
932                     do_discard_buffer(ctx);
933                     return EXCEPTION;
934                 }
935                 if (csum_ofs < 255) {
936                     csum_start = DECODE_U8();         /* 4th imm, at worst 7 B past prog_len */
937                     partial_csum = decode_be16(ctx);  /* 5th imm, at worst 9 B past prog_len */
938                 }
939                 {
940                     int dscp = apf_internal_csum_and_return_dscp(ctx->tx_buf, (s32)pkt_len, ip_ofs,
941                                                     partial_csum, csum_start, csum_ofs,
942                                                     (Boolean)reg_num);
943                     int ret = apf_internal_do_transmit_buffer(ctx, pkt_len, dscp);
944                     if (ret) { counter[-4]++; return EXCEPTION; } /* transmit failure */
945                 }
946                 break;
947               }
948               case EPKTDATACOPYIMM_EXT_OPCODE:  /* 41 */
949               case EPKTDATACOPYR1_EXT_OPCODE:   /* 42 */
950                 pktcopy_src_offset = ctx->R[0];
951                 FALLTHROUGH;
952               case PKTDATACOPYIMM_EXT_OPCODE: { /* 65536 */
953                 u32 dst_offs = ctx->mem.named.tx_buf_offset;
954                 u32 copy_len = ctx->R[1];
955                 if (imm != EPKTDATACOPYR1_EXT_OPCODE) {
956                     copy_len = DECODE_U8();  /* 2nd imm, at worst 8 bytes past prog_len */
957                 }
958                 ASSERT_RETURN(ctx->tx_buf);
959                 ASSERT_IN_OUTPUT_BOUNDS(dst_offs, copy_len);
960                 if (reg_num == 0) {  /* copy from packet */
961                     const u32 last_packet_offs = pktcopy_src_offset + copy_len - 1;
962                     ASSERT_IN_PACKET_BOUNDS(pktcopy_src_offset);
963                     ASSERT_RETURN(last_packet_offs >= pktcopy_src_offset);
964                     ASSERT_IN_PACKET_BOUNDS(last_packet_offs);
965                     memcpy(ctx->tx_buf + dst_offs, ctx->packet + pktcopy_src_offset, copy_len);
966                 } else {  /* copy from data */
967                     ASSERT_IN_RAM_BOUNDS(pktcopy_src_offset + copy_len - 1);
968                     memcpy(ctx->tx_buf + dst_offs, ctx->program + pktcopy_src_offset, copy_len);
969                 }
970                 dst_offs += copy_len;
971                 ctx->mem.named.tx_buf_offset = dst_offs;
972                 break;
973               }
974               case JDNSQMATCH_EXT_OPCODE:       /* 43 */
975               case JDNSAMATCH_EXT_OPCODE:       /* 44 */
976               case JDNSQMATCHSAFE_EXT_OPCODE:   /* 45 */
977               case JDNSAMATCHSAFE_EXT_OPCODE: { /* 46 */
978                 u32 jump_offs = decode_imm(ctx, imm_len); /* 2nd imm, at worst 8 B past prog_len */
979                 int qtype = -1;
980                 if (imm & 1) { /* JDNSQMATCH & JDNSQMATCHSAFE are *odd* extended opcodes */
981                     qtype = DECODE_U8();  /* 3rd imm, at worst 9 bytes past prog_len */
982                 }
983                 {
984                     u32 udp_payload_offset = ctx->R[0];
985                     match_result_type match_rst = apf_internal_match_names(ctx->program + ctx->pc,
986                                                               ctx->program + ctx->program_len,
987                                                               ctx->packet + udp_payload_offset,
988                                                               ctx->packet_len - udp_payload_offset,
989                                                               qtype);
990                     if (match_rst == error_program) return EXCEPTION;
991                     if (match_rst == error_packet) {
992                         counter[-5]++; /* increment error dns packet counter */
993                         return (imm >= JDNSQMATCHSAFE_EXT_OPCODE) ? PASS : DROP;
994                     }
995                     while (ctx->pc + 1 < ctx->program_len &&
996                            (ctx->program[ctx->pc] || ctx->program[ctx->pc + 1])) {
997                         ctx->pc++;
998                     }
999                     ctx->pc += 2;  /* skip the final double 0 needle end */
1000                     /* relies on reg_num in {0,1} and match_rst being {False=0, True=1} */
1001                     if (!(reg_num ^ (u32)match_rst)) ctx->pc += jump_offs;
1002                 }
1003                 break;
1004               }
1005               case EWRITE1_EXT_OPCODE:
1006               case EWRITE2_EXT_OPCODE:
1007               case EWRITE4_EXT_OPCODE: {
1008                 const u32 write_len = 1 << (imm - EWRITE1_EXT_OPCODE);
1009                 u32 i;
1010                 ASSERT_RETURN(ctx->tx_buf);
1011                 ASSERT_IN_OUTPUT_BOUNDS(ctx->mem.named.tx_buf_offset, write_len);
1012                 for (i = 0; i < write_len; ++i) {
1013                     ctx->tx_buf[ctx->mem.named.tx_buf_offset++] =
1014                         (u8)(REG >> (write_len - 1 - i) * 8);
1015                 }
1016                 break;
1017               }
1018               case JONEOF_EXT_OPCODE: {
1019                 u32 jump_offs = decode_imm(ctx, imm_len); /* 2nd imm, at worst 8 B past prog_len */
1020                 u8 imm3 = DECODE_U8();  /* 3rd imm, at worst 9 bytes past prog_len */
1021                 Boolean jmp = imm3 & 1;  /* =0 jmp on match, =1 jmp on no match */
1022                 u8 len = ((imm3 >> 1) & 3) + 1;  /* size [1..4] in bytes of an element */
1023                 u8 cnt = (imm3 >> 3) + 2;  /* number [2..33] of elements in set */
1024                 if (ctx->pc + cnt * len > ctx->program_len) return EXCEPTION;
1025                 while (cnt--) {
1026                     u32 v = 0;
1027                     int i;
1028                     for (i = 0; i < len; ++i) v = (v << 8) | DECODE_U8();
1029                     if (REG == v) jmp ^= True;
1030                 }
1031                 if (jmp) ctx->pc += jump_offs;
1032                 break;
1033               }
1034               case EXCEPTIONBUFFER_EXT_OPCODE: {
1035                 ctx->except_buf_sz = decode_be16(ctx);
1036                 break;
1037               }
1038               default:  /* Unknown extended opcode */
1039                 return EXCEPTION;  /* Bail out */
1040             }
1041             break;
1042           case LDDW_OPCODE:
1043           case STDW_OPCODE:
1044             if (ctx->v6) {
1045                 if (!imm) return EXCEPTION;
1046                 if (imm > 0xFFFF) return EXCEPTION;
1047                 if (imm * 4 > ctx->ram_len) return EXCEPTION;
1048                 if (opcode == LDDW_OPCODE) {
1049                     REG = counter[-(s32)imm];
1050                 } else {
1051                     counter[-(s32)imm] = REG;
1052                 }
1053             } else {
1054                 u32 size = 4;
1055                 u32 offs = OTHER_REG + (u32)signed_imm;
1056                 /* Negative offsets wrap around the end of the address space. */
1057                 /* This allows us to efficiently access the end of the */
1058                 /* address space with one-byte immediates without using %=. */
1059                 if (offs & 0x80000000) offs += ctx->ram_len;  /* unsigned overflow intended */
1060                 ASSERT_IN_DATA_BOUNDS(offs, size);
1061                 if (opcode == LDDW_OPCODE) {
1062                     u32 val = 0;
1063                     while (size--) val = (val << 8) | ctx->program[offs++];
1064                     REG = val;
1065                 } else {
1066                     u32 val = REG;
1067                     while (size--) {
1068                         ctx->program[offs++] = (val >> 24);
1069                         val <<= 8;
1070                     }
1071                 }
1072             }
1073             break;
1074           case WRITE_OPCODE: {
1075             ASSERT_RETURN(ctx->tx_buf);
1076             ASSERT_RETURN(len_field);
1077             {
1078                 const u32 write_len = 1 << (len_field - 1);
1079                 u32 i;
1080                 ASSERT_IN_OUTPUT_BOUNDS(ctx->mem.named.tx_buf_offset, write_len);
1081                 for (i = 0; i < write_len; ++i) {
1082                     ctx->tx_buf[ctx->mem.named.tx_buf_offset++] =
1083                         (u8)(imm >> (write_len - 1 - i) * 8);
1084                 }
1085             }
1086             break;
1087           }
1088           default:  /* Unknown opcode */
1089             return EXCEPTION;  /* Bail out */
1090         }
1091       }
1092     } while (instructions_remaining--);
1093     return EXCEPTION;
1094 }
1095 
apf_runner(void * ctx,u32 * const program,const u32 program_len,const u32 ram_len,const u8 * const packet,const u32 packet_len,const u32 filter_age_16384ths)1096 static int apf_runner(void* ctx, u32* const program, const u32 program_len,
1097                       const u32 ram_len, const u8* const packet,
1098                       const u32 packet_len, const u32 filter_age_16384ths) {
1099     /* Due to direct 32-bit read/write access to counters at end of ram */
1100     /* APFv6 interpreter requires program & ram_len to be 4 byte aligned. */
1101     if (3 & (uintptr_t)program) return EXCEPTION;
1102     if (3 & ram_len) return EXCEPTION;
1103 
1104     /* We rely on ram_len + 65536 not overflowing, so require ram_len < 2GiB */
1105     /* Similarly LDDW/STDW have special meaning for negative ram offsets. */
1106     /* We also don't want garbage like program_len == 0xFFFFFFFF */
1107     if ((program_len | ram_len) >> 31) return EXCEPTION;
1108 
1109     {
1110         apf_context apf_ctx = { 0 };
1111         int ret;
1112 
1113         apf_ctx.ptr_size = sizeof(void*);
1114         apf_ctx.caller_ctx = ctx;
1115         apf_ctx.program = (u8*)program;
1116         apf_ctx.program_len = program_len;
1117         apf_ctx.ram_len = ram_len;
1118         apf_ctx.packet = packet;
1119         apf_ctx.packet_len = packet_len;
1120         /* Fill in pre-filled memory slot values. */
1121         apf_ctx.mem.named.program_size = program_len;
1122         apf_ctx.mem.named.ram_len = ram_len;
1123         apf_ctx.mem.named.packet_size = packet_len;
1124         apf_ctx.mem.named.apf_version = apf_version();
1125         apf_ctx.mem.named.filter_age = filter_age_16384ths >> 14;
1126         apf_ctx.mem.named.filter_age_16384ths = filter_age_16384ths;
1127 
1128         ret = do_apf_run(&apf_ctx);
1129         if (apf_ctx.tx_buf) do_discard_buffer(&apf_ctx);
1130         /* Convert any exceptions internal to the program to just normal 'PASS' */
1131         if (ret >= EXCEPTION) {
1132             u16 buf_size = apf_ctx.except_buf_sz;
1133             if (buf_size >= sizeof(apf_ctx) && apf_ctx.program_len + buf_size <= apf_ctx.ram_len) {
1134                 u8* buf = apf_ctx.program + apf_ctx.program_len;
1135                 memcpy(buf, &apf_ctx, sizeof(apf_ctx));
1136                 buf_size -= sizeof(apf_ctx);
1137                 buf += sizeof(apf_ctx);
1138                 if (buf_size > apf_ctx.packet_len) buf_size = apf_ctx.packet_len;
1139                 memcpy(buf, apf_ctx.packet, buf_size);
1140             }
1141             ret = PASS;
1142         }
1143         return ret;
1144     }
1145 }
1146 
apf_run(void * ctx,u32 * const program,const u32 program_len,const u32 ram_len,const u8 * const packet,const u32 packet_len,const u32 filter_age_16384ths)1147 int apf_run(void* ctx, u32* const program, const u32 program_len,
1148             const u32 ram_len, const u8* const packet,
1149             const u32 packet_len, const u32 filter_age_16384ths) {
1150     /* Any valid ethernet packet should be at least ETH_HLEN long... */
1151     if (!packet) return EXCEPTION;
1152     if (packet_len < ETH_HLEN) return EXCEPTION;
1153 
1154     return apf_runner(ctx, program, program_len, ram_len, packet, packet_len, filter_age_16384ths);
1155 }
1156