1 /*
2  * Copyright 2016, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "apf_interpreter.h"
18 
19 #include <string.h> // For memcmp
20 
21 #include "apf.h"
22 
// Return code indicating "packet" should be accepted.
#define PASS_PACKET 1
// Return code indicating "packet" should be dropped.
#define DROP_PACKET 0
// Verify an internal condition and accept packet if it fails.
// Wrapped in do { } while (0) so the expansion is exactly one statement: a
// bare "if" here would silently capture any "else" that follows a use of the
// macro inside an unbraced if/else in the caller.
#define ASSERT_RETURN(c) do { if (!(c)) return PASS_PACKET; } while (0)
// Evaluates to true for any uint32_t-representable value. Its purpose is the
// comparison itself: if "c" is NOT of an unsigned type, comparing it against
// its uint32_t conversion is intended to trigger a signed/unsigned comparison
// warning (expected to be promoted to an error by the build flags).
// This makes bounds checking simpler because ">= 0" can be avoided. Otherwise adding
// superfluous ">= 0" with unsigned expressions generates compile warnings.
#define ENFORCE_UNSIGNED(c) ((c)==(uint32_t)(c))
33 
34 /**
35  * Runs a packet filtering program over a packet.
36  *
37  * @param program the program bytecode.
38  * @param program_len the length of {@code apf_program} in bytes.
39  * @param packet the packet bytes, starting from the 802.3 header and not
40  *               including any CRC bytes at the end.
41  * @param packet_len the length of {@code packet} in bytes.
42  * @param filter_age the number of seconds since the filter was programmed.
43  *
44  * @return non-zero if packet should be passed to AP, zero if
45  *         packet should be dropped.
46  */
accept_packet(const uint8_t * program,uint32_t program_len,const uint8_t * packet,uint32_t packet_len,uint32_t filter_age)47 int accept_packet(const uint8_t* program, uint32_t program_len,
48                   const uint8_t* packet, uint32_t packet_len,
49                   uint32_t filter_age) {
50 // Is offset within program bounds?
51 #define IN_PROGRAM_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < program_len)
52 // Is offset within packet bounds?
53 #define IN_PACKET_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < packet_len)
54 // Accept packet if not within program bounds
55 #define ASSERT_IN_PROGRAM_BOUNDS(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p))
56 // Accept packet if not within packet bounds
57 #define ASSERT_IN_PACKET_BOUNDS(p) ASSERT_RETURN(IN_PACKET_BOUNDS(p))
58   // Program counter.
59   uint32_t pc = 0;
60 // Accept packet if not within program or not ahead of program counter
61 #define ASSERT_FORWARD_IN_PROGRAM(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p) && (p) >= pc)
62   // Memory slot values.
63   uint32_t memory[MEMORY_ITEMS] = {};
64   // Fill in pre-filled memory slot values.
65   memory[MEMORY_OFFSET_PACKET_SIZE] = packet_len;
66   memory[MEMORY_OFFSET_FILTER_AGE] = filter_age;
67   ASSERT_IN_PACKET_BOUNDS(APF_FRAME_HEADER_SIZE);
68   // Only populate if IP version is IPv4.
69   if ((packet[APF_FRAME_HEADER_SIZE] & 0xf0) == 0x40) {
70       memory[MEMORY_OFFSET_IPV4_HEADER_SIZE] = (packet[APF_FRAME_HEADER_SIZE] & 15) * 4;
71   }
72   // Register values.
73   uint32_t registers[2] = {};
74   // Count of instructions remaining to execute. This is done to ensure an
75   // upper bound on execution time. It should never be hit and is only for
76   // safety. Initialize to the number of bytes in the program which is an
77   // upper bound on the number of instructions in the program.
78   uint32_t instructions_remaining = program_len;
79 
80   do {
81       if (pc == program_len) {
82           return PASS_PACKET;
83       } else if (pc == (program_len + 1)) {
84           return DROP_PACKET;
85       }
86       ASSERT_IN_PROGRAM_BOUNDS(pc);
87       const uint8_t bytecode = program[pc++];
88       const uint32_t opcode = EXTRACT_OPCODE(bytecode);
89       const uint32_t reg_num = EXTRACT_REGISTER(bytecode);
90 #define REG (registers[reg_num])
91 #define OTHER_REG (registers[reg_num ^ 1])
92       // All instructions have immediate fields, so load them now.
93       const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode);
94       uint32_t imm = 0;
95       int32_t signed_imm = 0;
96       if (len_field != 0) {
97           const uint32_t imm_len = 1 << (len_field - 1);
98           ASSERT_FORWARD_IN_PROGRAM(pc + imm_len - 1);
99           uint32_t i;
100           for (i = 0; i < imm_len; i++)
101               imm = (imm << 8) | program[pc++];
102           // Sign extend imm into signed_imm.
103           signed_imm = imm << ((4 - imm_len) * 8);
104           signed_imm >>= (4 - imm_len) * 8;
105       }
106       switch (opcode) {
107           case LDB_OPCODE:
108           case LDH_OPCODE:
109           case LDW_OPCODE:
110           case LDBX_OPCODE:
111           case LDHX_OPCODE:
112           case LDWX_OPCODE: {
113               uint32_t offs = imm;
114               if (opcode >= LDBX_OPCODE) {
115                   // Note: this can overflow and actually decrease offs.
116                   offs += registers[1];
117               }
118               ASSERT_IN_PACKET_BOUNDS(offs);
119               uint32_t load_size;
120               switch (opcode) {
121                   case LDB_OPCODE:
122                   case LDBX_OPCODE:
123                     load_size = 1;
124                     break;
125                   case LDH_OPCODE:
126                   case LDHX_OPCODE:
127                     load_size = 2;
128                     break;
129                   case LDW_OPCODE:
130                   case LDWX_OPCODE:
131                     load_size = 4;
132                     break;
133                   // Immediately enclosing switch statement guarantees
134                   // opcode cannot be any other value.
135               }
136               const uint32_t end_offs = offs + (load_size - 1);
137               // Catch overflow/wrap-around.
138               ASSERT_RETURN(end_offs >= offs);
139               ASSERT_IN_PACKET_BOUNDS(end_offs);
140               uint32_t val = 0;
141               while (load_size--)
142                   val = (val << 8) | packet[offs++];
143               REG = val;
144               break;
145           }
146           case JMP_OPCODE:
147               // This can jump backwards. Infinite looping prevented by instructions_remaining.
148               pc += imm;
149               break;
150           case JEQ_OPCODE:
151           case JNE_OPCODE:
152           case JGT_OPCODE:
153           case JLT_OPCODE:
154           case JSET_OPCODE:
155           case JNEBS_OPCODE: {
156               // Load second immediate field.
157               uint32_t cmp_imm = 0;
158               if (reg_num == 1) {
159                   cmp_imm = registers[1];
160               } else if (len_field != 0) {
161                   uint32_t cmp_imm_len = 1 << (len_field - 1);
162                   ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm_len - 1);
163                   uint32_t i;
164                   for (i = 0; i < cmp_imm_len; i++)
165                       cmp_imm = (cmp_imm << 8) | program[pc++];
166               }
167               switch (opcode) {
168                   case JEQ_OPCODE:
169                       if (registers[0] == cmp_imm)
170                           pc += imm;
171                       break;
172                   case JNE_OPCODE:
173                       if (registers[0] != cmp_imm)
174                           pc += imm;
175                       break;
176                   case JGT_OPCODE:
177                       if (registers[0] > cmp_imm)
178                           pc += imm;
179                       break;
180                   case JLT_OPCODE:
181                       if (registers[0] < cmp_imm)
182                           pc += imm;
183                       break;
184                   case JSET_OPCODE:
185                       if (registers[0] & cmp_imm)
186                           pc += imm;
187                       break;
188                   case JNEBS_OPCODE: {
189                       // cmp_imm is size in bytes of data to compare.
190                       // pc is offset of program bytes to compare.
191                       // imm is jump target offset.
192                       // REG is offset of packet bytes to compare.
193                       ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1);
194                       ASSERT_IN_PACKET_BOUNDS(REG);
195                       const uint32_t last_packet_offs = REG + cmp_imm - 1;
196                       ASSERT_RETURN(last_packet_offs >= REG);
197                       ASSERT_IN_PACKET_BOUNDS(last_packet_offs);
198                       if (memcmp(program + pc, packet + REG, cmp_imm))
199                           pc += imm;
200                       // skip past comparison bytes
201                       pc += cmp_imm;
202                       break;
203                   }
204               }
205               break;
206           }
207           case ADD_OPCODE:
208               registers[0] += reg_num ? registers[1] : imm;
209               break;
210           case MUL_OPCODE:
211               registers[0] *= reg_num ? registers[1] : imm;
212               break;
213           case DIV_OPCODE: {
214               const uint32_t div_operand = reg_num ? registers[1] : imm;
215               ASSERT_RETURN(div_operand);
216               registers[0] /= div_operand;
217               break;
218           }
219           case AND_OPCODE:
220               registers[0] &= reg_num ? registers[1] : imm;
221               break;
222           case OR_OPCODE:
223               registers[0] |= reg_num ? registers[1] : imm;
224               break;
225           case SH_OPCODE: {
226               const int32_t shift_val = reg_num ? (int32_t)registers[1] : signed_imm;
227               if (shift_val > 0)
228                   registers[0] <<= shift_val;
229               else
230                   registers[0] >>= -shift_val;
231               break;
232           }
233           case LI_OPCODE:
234               REG = signed_imm;
235               break;
236           case EXT_OPCODE:
237               if (
238 // If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result,
239 // instead just enforce that imm is unsigned (so it's always greater or equal to 0).
240 #if LDM_EXT_OPCODE == 0
241                   ENFORCE_UNSIGNED(imm) &&
242 #else
243                   imm >= LDM_EXT_OPCODE &&
244 #endif
245                   imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
246                 REG = memory[imm - LDM_EXT_OPCODE];
247               } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
248                 memory[imm - STM_EXT_OPCODE] = REG;
249               } else switch (imm) {
250                   case NOT_EXT_OPCODE:
251                     REG = ~REG;
252                     break;
253                   case NEG_EXT_OPCODE:
254                     REG = -REG;
255                     break;
256                   case SWAP_EXT_OPCODE: {
257                     uint32_t tmp = REG;
258                     REG = OTHER_REG;
259                     OTHER_REG = tmp;
260                     break;
261                   }
262                   case MOV_EXT_OPCODE:
263                     REG = OTHER_REG;
264                     break;
265                   // Unknown extended opcode
266                   default:
267                     // Bail out
268                     return PASS_PACKET;
269               }
270               break;
271           // Unknown opcode
272           default:
273               // Bail out
274               return PASS_PACKET;
275       }
276   } while (instructions_remaining--);
277   return PASS_PACKET;
278 }
279