1 /*
2  * Copyright 2018, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "apf_interpreter.h"
18 
19 #include <string.h> // For memcmp
20 
21 #include "apf.h"
22 
// Optional debug-tracing hook for the interpreter.
// When the build defines APF_TRACE_HOOK as the name of a function, that
// function is declared here with the signature below and must be defined
// elsewhere. When it is left undefined, every use of the hook expands to an
// empty expression and none of its arguments are evaluated.
#ifndef APF_TRACE_HOOK
#define APF_TRACE_HOOK(cur_pc, cur_regs, prog, prog_len, pkt, pkt_len, mem, mem_len) \
    ((void)0) /* nop */
#else
extern void APF_TRACE_HOOK(uint32_t cur_pc, const uint32_t* cur_regs, const uint8_t* prog,
                           uint32_t prog_len, const uint8_t* pkt, uint32_t pkt_len,
                           const uint32_t* mem, uint32_t mem_len);
#endif
33 
// Return code indicating "packet" should be accepted.
#define PASS_PACKET 1
// Return code indicating "packet" should be dropped.
#define DROP_PACKET 0
// Verify an internal condition and accept packet if it fails.
// Wrapped in do/while(0) so the expansion is a single statement: it needs a
// trailing ';' and cannot capture an "else" that follows it.
#define ASSERT_RETURN(c) do { if (!(c)) return PASS_PACKET; } while (0)
// If "c" is NOT of an unsigned type, the comparison below mixes signed and
// unsigned operands, which generates a compile warning (expected to be promoted
// to an error by the build flags). For unsigned "c" no wider than 32 bits the
// expression is always true.
// This makes bounds checking simpler because ">= 0" can be avoided. Otherwise adding
// superfluous ">= 0" with unsigned expressions generates compile warnings.
#define ENFORCE_UNSIGNED(c) ((c)==(uint32_t)(c))
44 
accept_packet(uint8_t * program,uint32_t program_len,uint32_t ram_len,const uint8_t * packet,uint32_t packet_len,uint32_t filter_age)45 int accept_packet(uint8_t* program, uint32_t program_len, uint32_t ram_len,
46                   const uint8_t* packet, uint32_t packet_len,
47                   uint32_t filter_age) {
48 // Is offset within program bounds?
49 #define IN_PROGRAM_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < program_len)
50 // Is offset within packet bounds?
51 #define IN_PACKET_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < packet_len)
52 // Is access to offset |p| length |size| within data bounds?
53 #define IN_DATA_BOUNDS(p, size) (ENFORCE_UNSIGNED(p) && \
54                                  ENFORCE_UNSIGNED(size) && \
55                                  (p) + (size) <= ram_len && \
56                                  (p) >= program_len && \
57                                  (p) + (size) >= (p))  // catch wraparounds
58 // Accept packet if not within program bounds
59 #define ASSERT_IN_PROGRAM_BOUNDS(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p))
60 // Accept packet if not within packet bounds
61 #define ASSERT_IN_PACKET_BOUNDS(p) ASSERT_RETURN(IN_PACKET_BOUNDS(p))
62 // Accept packet if not within data bounds
63 #define ASSERT_IN_DATA_BOUNDS(p, size) ASSERT_RETURN(IN_DATA_BOUNDS(p, size))
64 
65   // Program counter.
66   uint32_t pc = 0;
67 // Accept packet if not within program or not ahead of program counter
68 #define ASSERT_FORWARD_IN_PROGRAM(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p) && (p) >= pc)
69   // Memory slot values.
70   uint32_t memory[MEMORY_ITEMS] = {};
71   // Fill in pre-filled memory slot values.
72   memory[MEMORY_OFFSET_PROGRAM_SIZE] = program_len;
73   memory[MEMORY_OFFSET_DATA_SIZE] = ram_len;
74   memory[MEMORY_OFFSET_PACKET_SIZE] = packet_len;
75   memory[MEMORY_OFFSET_FILTER_AGE] = filter_age;
76   ASSERT_IN_PACKET_BOUNDS(APF_FRAME_HEADER_SIZE);
77   // Only populate if IP version is IPv4.
78   if ((packet[APF_FRAME_HEADER_SIZE] & 0xf0) == 0x40) {
79       memory[MEMORY_OFFSET_IPV4_HEADER_SIZE] = (packet[APF_FRAME_HEADER_SIZE] & 15) * 4;
80   }
81   // Register values.
82   uint32_t registers[2] = {};
83   // Count of instructions remaining to execute. This is done to ensure an
84   // upper bound on execution time. It should never be hit and is only for
85   // safety. Initialize to the number of bytes in the program which is an
86   // upper bound on the number of instructions in the program.
87   uint32_t instructions_remaining = program_len;
88 
89   do {
90       APF_TRACE_HOOK(pc, registers, program, program_len, packet, packet_len, memory, ram_len);
91       if (pc == program_len) {
92           return PASS_PACKET;
93       } else if (pc == (program_len + 1)) {
94           return DROP_PACKET;
95       }
96       ASSERT_IN_PROGRAM_BOUNDS(pc);
97       const uint8_t bytecode = program[pc++];
98       const uint32_t opcode = EXTRACT_OPCODE(bytecode);
99       const uint32_t reg_num = EXTRACT_REGISTER(bytecode);
100 #define REG (registers[reg_num])
101 #define OTHER_REG (registers[reg_num ^ 1])
102       // All instructions have immediate fields, so load them now.
103       const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode);
104       uint32_t imm = 0;
105       int32_t signed_imm = 0;
106       if (len_field != 0) {
107           const uint32_t imm_len = 1 << (len_field - 1);
108           ASSERT_FORWARD_IN_PROGRAM(pc + imm_len - 1);
109           uint32_t i;
110           for (i = 0; i < imm_len; i++)
111               imm = (imm << 8) | program[pc++];
112           // Sign extend imm into signed_imm.
113           signed_imm = imm << ((4 - imm_len) * 8);
114           signed_imm >>= (4 - imm_len) * 8;
115       }
116 
117       switch (opcode) {
118           case LDB_OPCODE:
119           case LDH_OPCODE:
120           case LDW_OPCODE:
121           case LDBX_OPCODE:
122           case LDHX_OPCODE:
123           case LDWX_OPCODE: {
124               uint32_t offs = imm;
125               if (opcode >= LDBX_OPCODE) {
126                   // Note: this can overflow and actually decrease offs.
127                   offs += registers[1];
128               }
129               ASSERT_IN_PACKET_BOUNDS(offs);
130               uint32_t load_size;
131               switch (opcode) {
132                   case LDB_OPCODE:
133                   case LDBX_OPCODE:
134                     load_size = 1;
135                     break;
136                   case LDH_OPCODE:
137                   case LDHX_OPCODE:
138                     load_size = 2;
139                     break;
140                   case LDW_OPCODE:
141                   case LDWX_OPCODE:
142                     load_size = 4;
143                     break;
144                   // Immediately enclosing switch statement guarantees
145                   // opcode cannot be any other value.
146               }
147               const uint32_t end_offs = offs + (load_size - 1);
148               // Catch overflow/wrap-around.
149               ASSERT_RETURN(end_offs >= offs);
150               ASSERT_IN_PACKET_BOUNDS(end_offs);
151               uint32_t val = 0;
152               while (load_size--)
153                   val = (val << 8) | packet[offs++];
154               REG = val;
155               break;
156           }
157           case JMP_OPCODE:
158               // This can jump backwards. Infinite looping prevented by instructions_remaining.
159               pc += imm;
160               break;
161           case JEQ_OPCODE:
162           case JNE_OPCODE:
163           case JGT_OPCODE:
164           case JLT_OPCODE:
165           case JSET_OPCODE:
166           case JNEBS_OPCODE: {
167               // Load second immediate field.
168               uint32_t cmp_imm = 0;
169               if (reg_num == 1) {
170                   cmp_imm = registers[1];
171               } else if (len_field != 0) {
172                   uint32_t cmp_imm_len = 1 << (len_field - 1);
173                   ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm_len - 1);
174                   uint32_t i;
175                   for (i = 0; i < cmp_imm_len; i++)
176                       cmp_imm = (cmp_imm << 8) | program[pc++];
177               }
178               switch (opcode) {
179                   case JEQ_OPCODE:
180                       if (registers[0] == cmp_imm)
181                           pc += imm;
182                       break;
183                   case JNE_OPCODE:
184                       if (registers[0] != cmp_imm)
185                           pc += imm;
186                       break;
187                   case JGT_OPCODE:
188                       if (registers[0] > cmp_imm)
189                           pc += imm;
190                       break;
191                   case JLT_OPCODE:
192                       if (registers[0] < cmp_imm)
193                           pc += imm;
194                       break;
195                   case JSET_OPCODE:
196                       if (registers[0] & cmp_imm)
197                           pc += imm;
198                       break;
199                   case JNEBS_OPCODE: {
200                       // cmp_imm is size in bytes of data to compare.
201                       // pc is offset of program bytes to compare.
202                       // imm is jump target offset.
203                       // REG is offset of packet bytes to compare.
204                       ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1);
205                       ASSERT_IN_PACKET_BOUNDS(REG);
206                       const uint32_t last_packet_offs = REG + cmp_imm - 1;
207                       ASSERT_RETURN(last_packet_offs >= REG);
208                       ASSERT_IN_PACKET_BOUNDS(last_packet_offs);
209                       if (memcmp(program + pc, packet + REG, cmp_imm))
210                           pc += imm;
211                       // skip past comparison bytes
212                       pc += cmp_imm;
213                       break;
214                   }
215               }
216               break;
217           }
218           case ADD_OPCODE:
219               registers[0] += reg_num ? registers[1] : imm;
220               break;
221           case MUL_OPCODE:
222               registers[0] *= reg_num ? registers[1] : imm;
223               break;
224           case DIV_OPCODE: {
225               const uint32_t div_operand = reg_num ? registers[1] : imm;
226               ASSERT_RETURN(div_operand);
227               registers[0] /= div_operand;
228               break;
229           }
230           case AND_OPCODE:
231               registers[0] &= reg_num ? registers[1] : imm;
232               break;
233           case OR_OPCODE:
234               registers[0] |= reg_num ? registers[1] : imm;
235               break;
236           case SH_OPCODE: {
237               const int32_t shift_val = reg_num ? (int32_t)registers[1] : signed_imm;
238               if (shift_val > 0)
239                   registers[0] <<= shift_val;
240               else
241                   registers[0] >>= -shift_val;
242               break;
243           }
244           case LI_OPCODE:
245               REG = signed_imm;
246               break;
247           case EXT_OPCODE:
248               if (
249 // If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result,
250 // instead just enforce that imm is unsigned (so it's always greater or equal to 0).
251 #if LDM_EXT_OPCODE == 0
252                   ENFORCE_UNSIGNED(imm) &&
253 #else
254                   imm >= LDM_EXT_OPCODE &&
255 #endif
256                   imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
257                 REG = memory[imm - LDM_EXT_OPCODE];
258               } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
259                 memory[imm - STM_EXT_OPCODE] = REG;
260               } else switch (imm) {
261                   case NOT_EXT_OPCODE:
262                     REG = ~REG;
263                     break;
264                   case NEG_EXT_OPCODE:
265                     REG = -REG;
266                     break;
267                   case SWAP_EXT_OPCODE: {
268                     uint32_t tmp = REG;
269                     REG = OTHER_REG;
270                     OTHER_REG = tmp;
271                     break;
272                   }
273                   case MOV_EXT_OPCODE:
274                     REG = OTHER_REG;
275                     break;
276                   // Unknown extended opcode
277                   default:
278                     // Bail out
279                     return PASS_PACKET;
280               }
281               break;
282           case LDDW_OPCODE: {
283               uint32_t offs = OTHER_REG + signed_imm;
284               uint32_t size = 4;
285               uint32_t val = 0;
286               // Negative offsets wrap around the end of the address space.
287               // This allows us to efficiently access the end of the
288               // address space with one-byte immediates without using %=.
289               if (offs & 0x80000000) {
290                   offs = ram_len + offs;  // unsigned overflow intended
291               }
292               ASSERT_IN_DATA_BOUNDS(offs, size);
293               while (size--)
294                   val = (val << 8) | program[offs++];
295               REG = val;
296               break;
297           }
298           case STDW_OPCODE: {
299               uint32_t offs = OTHER_REG + signed_imm;
300               uint32_t size = 4;
301               uint32_t val = REG;
302               // Negative offsets wrap around the end of the address space.
303               // This allows us to efficiently access the end of the
304               // address space with one-byte immediates without using %=.
305               if (offs & 0x80000000) {
306                   offs = ram_len + offs;  // unsigned overflow intended
307               }
308               ASSERT_IN_DATA_BOUNDS(offs, size);
309               while (size--) {
310                   program[offs++] = (val >> 24);
311                   val <<= 8;
312               }
313               break;
314           }
315           // Unknown opcode
316           default:
317               // Bail out
318               return PASS_PACKET;
319       }
320   } while (instructions_remaining--);
321   return PASS_PACKET;
322 }
323