1 /*
2 * Copyright 2016, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "apf_interpreter.h"
18
19 #include <string.h> // For memcmp
20
21 #include "apf.h"
22
// Return code indicating "packet" should be accepted.
#define PASS_PACKET 1
// Return code indicating "packet" should be dropped.
#define DROP_PACKET 0
// Verify an internal condition and accept packet if it fails.
// Wrapped in do/while(0) so the macro is a single statement: a bare "if"
// expansion would capture a following "else" (dangling-else hazard).
#define ASSERT_RETURN(c) do { if (!(c)) return PASS_PACKET; } while (0)
// If "c" is of a signed type, comparing it against its uint32_t conversion
// generates a signed/unsigned comparison warning, which the build is expected
// to promote to an error. This makes bounds checking simpler because ">= 0"
// can be avoided. Otherwise adding superfluous ">= 0" with unsigned
// expressions generates compile warnings.
#define ENFORCE_UNSIGNED(c) ((c)==(uint32_t)(c))
33
34 /**
35 * Runs a packet filtering program over a packet.
36 *
37 * @param program the program bytecode.
38 * @param program_len the length of {@code apf_program} in bytes.
39 * @param packet the packet bytes, starting from the 802.3 header and not
40 * including any CRC bytes at the end.
41 * @param packet_len the length of {@code packet} in bytes.
42 * @param filter_age the number of seconds since the filter was programmed.
43 *
44 * @return non-zero if packet should be passed to AP, zero if
45 * packet should be dropped.
46 */
accept_packet(const uint8_t * program,uint32_t program_len,const uint8_t * packet,uint32_t packet_len,uint32_t filter_age)47 int accept_packet(const uint8_t* program, uint32_t program_len,
48 const uint8_t* packet, uint32_t packet_len,
49 uint32_t filter_age) {
50 // Is offset within program bounds?
51 #define IN_PROGRAM_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < program_len)
52 // Is offset within packet bounds?
53 #define IN_PACKET_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < packet_len)
54 // Accept packet if not within program bounds
55 #define ASSERT_IN_PROGRAM_BOUNDS(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p))
56 // Accept packet if not within packet bounds
57 #define ASSERT_IN_PACKET_BOUNDS(p) ASSERT_RETURN(IN_PACKET_BOUNDS(p))
58 // Program counter.
59 uint32_t pc = 0;
60 // Accept packet if not within program or not ahead of program counter
61 #define ASSERT_FORWARD_IN_PROGRAM(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p) && (p) >= pc)
62 // Memory slot values.
63 uint32_t memory[MEMORY_ITEMS] = {};
64 // Fill in pre-filled memory slot values.
65 memory[MEMORY_OFFSET_PACKET_SIZE] = packet_len;
66 memory[MEMORY_OFFSET_FILTER_AGE] = filter_age;
67 ASSERT_IN_PACKET_BOUNDS(APF_FRAME_HEADER_SIZE);
68 // Only populate if IP version is IPv4.
69 if ((packet[APF_FRAME_HEADER_SIZE] & 0xf0) == 0x40) {
70 memory[MEMORY_OFFSET_IPV4_HEADER_SIZE] = (packet[APF_FRAME_HEADER_SIZE] & 15) * 4;
71 }
72 // Register values.
73 uint32_t registers[2] = {};
74 // Count of instructions remaining to execute. This is done to ensure an
75 // upper bound on execution time. It should never be hit and is only for
76 // safety. Initialize to the number of bytes in the program which is an
77 // upper bound on the number of instructions in the program.
78 uint32_t instructions_remaining = program_len;
79
80 do {
81 if (pc == program_len) {
82 return PASS_PACKET;
83 } else if (pc == (program_len + 1)) {
84 return DROP_PACKET;
85 }
86 ASSERT_IN_PROGRAM_BOUNDS(pc);
87 const uint8_t bytecode = program[pc++];
88 const uint32_t opcode = EXTRACT_OPCODE(bytecode);
89 const uint32_t reg_num = EXTRACT_REGISTER(bytecode);
90 #define REG (registers[reg_num])
91 #define OTHER_REG (registers[reg_num ^ 1])
92 // All instructions have immediate fields, so load them now.
93 const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode);
94 uint32_t imm = 0;
95 int32_t signed_imm = 0;
96 if (len_field != 0) {
97 const uint32_t imm_len = 1 << (len_field - 1);
98 ASSERT_FORWARD_IN_PROGRAM(pc + imm_len - 1);
99 uint32_t i;
100 for (i = 0; i < imm_len; i++)
101 imm = (imm << 8) | program[pc++];
102 // Sign extend imm into signed_imm.
103 signed_imm = imm << ((4 - imm_len) * 8);
104 signed_imm >>= (4 - imm_len) * 8;
105 }
106 switch (opcode) {
107 case LDB_OPCODE:
108 case LDH_OPCODE:
109 case LDW_OPCODE:
110 case LDBX_OPCODE:
111 case LDHX_OPCODE:
112 case LDWX_OPCODE: {
113 uint32_t offs = imm;
114 if (opcode >= LDBX_OPCODE) {
115 // Note: this can overflow and actually decrease offs.
116 offs += registers[1];
117 }
118 ASSERT_IN_PACKET_BOUNDS(offs);
119 uint32_t load_size;
120 switch (opcode) {
121 case LDB_OPCODE:
122 case LDBX_OPCODE:
123 load_size = 1;
124 break;
125 case LDH_OPCODE:
126 case LDHX_OPCODE:
127 load_size = 2;
128 break;
129 case LDW_OPCODE:
130 case LDWX_OPCODE:
131 load_size = 4;
132 break;
133 // Immediately enclosing switch statement guarantees
134 // opcode cannot be any other value.
135 }
136 const uint32_t end_offs = offs + (load_size - 1);
137 // Catch overflow/wrap-around.
138 ASSERT_RETURN(end_offs >= offs);
139 ASSERT_IN_PACKET_BOUNDS(end_offs);
140 uint32_t val = 0;
141 while (load_size--)
142 val = (val << 8) | packet[offs++];
143 REG = val;
144 break;
145 }
146 case JMP_OPCODE:
147 // This can jump backwards. Infinite looping prevented by instructions_remaining.
148 pc += imm;
149 break;
150 case JEQ_OPCODE:
151 case JNE_OPCODE:
152 case JGT_OPCODE:
153 case JLT_OPCODE:
154 case JSET_OPCODE:
155 case JNEBS_OPCODE: {
156 // Load second immediate field.
157 uint32_t cmp_imm = 0;
158 if (reg_num == 1) {
159 cmp_imm = registers[1];
160 } else if (len_field != 0) {
161 uint32_t cmp_imm_len = 1 << (len_field - 1);
162 ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm_len - 1);
163 uint32_t i;
164 for (i = 0; i < cmp_imm_len; i++)
165 cmp_imm = (cmp_imm << 8) | program[pc++];
166 }
167 switch (opcode) {
168 case JEQ_OPCODE:
169 if (registers[0] == cmp_imm)
170 pc += imm;
171 break;
172 case JNE_OPCODE:
173 if (registers[0] != cmp_imm)
174 pc += imm;
175 break;
176 case JGT_OPCODE:
177 if (registers[0] > cmp_imm)
178 pc += imm;
179 break;
180 case JLT_OPCODE:
181 if (registers[0] < cmp_imm)
182 pc += imm;
183 break;
184 case JSET_OPCODE:
185 if (registers[0] & cmp_imm)
186 pc += imm;
187 break;
188 case JNEBS_OPCODE: {
189 // cmp_imm is size in bytes of data to compare.
190 // pc is offset of program bytes to compare.
191 // imm is jump target offset.
192 // REG is offset of packet bytes to compare.
193 ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1);
194 ASSERT_IN_PACKET_BOUNDS(REG);
195 const uint32_t last_packet_offs = REG + cmp_imm - 1;
196 ASSERT_RETURN(last_packet_offs >= REG);
197 ASSERT_IN_PACKET_BOUNDS(last_packet_offs);
198 if (memcmp(program + pc, packet + REG, cmp_imm))
199 pc += imm;
200 // skip past comparison bytes
201 pc += cmp_imm;
202 break;
203 }
204 }
205 break;
206 }
207 case ADD_OPCODE:
208 registers[0] += reg_num ? registers[1] : imm;
209 break;
210 case MUL_OPCODE:
211 registers[0] *= reg_num ? registers[1] : imm;
212 break;
213 case DIV_OPCODE: {
214 const uint32_t div_operand = reg_num ? registers[1] : imm;
215 ASSERT_RETURN(div_operand);
216 registers[0] /= div_operand;
217 break;
218 }
219 case AND_OPCODE:
220 registers[0] &= reg_num ? registers[1] : imm;
221 break;
222 case OR_OPCODE:
223 registers[0] |= reg_num ? registers[1] : imm;
224 break;
225 case SH_OPCODE: {
226 const int32_t shift_val = reg_num ? (int32_t)registers[1] : signed_imm;
227 if (shift_val > 0)
228 registers[0] <<= shift_val;
229 else
230 registers[0] >>= -shift_val;
231 break;
232 }
233 case LI_OPCODE:
234 REG = signed_imm;
235 break;
236 case EXT_OPCODE:
237 if (
238 // If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result,
239 // instead just enforce that imm is unsigned (so it's always greater or equal to 0).
240 #if LDM_EXT_OPCODE == 0
241 ENFORCE_UNSIGNED(imm) &&
242 #else
243 imm >= LDM_EXT_OPCODE &&
244 #endif
245 imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
246 REG = memory[imm - LDM_EXT_OPCODE];
247 } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
248 memory[imm - STM_EXT_OPCODE] = REG;
249 } else switch (imm) {
250 case NOT_EXT_OPCODE:
251 REG = ~REG;
252 break;
253 case NEG_EXT_OPCODE:
254 REG = -REG;
255 break;
256 case SWAP_EXT_OPCODE: {
257 uint32_t tmp = REG;
258 REG = OTHER_REG;
259 OTHER_REG = tmp;
260 break;
261 }
262 case MOV_EXT_OPCODE:
263 REG = OTHER_REG;
264 break;
265 // Unknown extended opcode
266 default:
267 // Bail out
268 return PASS_PACKET;
269 }
270 break;
271 // Unknown opcode
272 default:
273 // Bail out
274 return PASS_PACKET;
275 }
276 } while (instructions_remaining--);
277 return PASS_PACKET;
278 }
279