/*
 * Copyright 2016, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "apf_interpreter.h"

#include <string.h>  // For memcmp

#include "apf.h"

// Return code indicating "packet" should be accepted.
#define PASS_PACKET 1

// Return code indicating "packet" should be dropped.
#define DROP_PACKET 0

// Verify an internal condition and accept packet if it fails.
// Wrapped in do/while(0) so that it expands to exactly one statement and cannot
// capture a following "else" when used inside an unbraced if/else.
#define ASSERT_RETURN(c) do { if (!(c)) return PASS_PACKET; } while (0)

// If "c" is of a signed type, the signed/unsigned comparison below generates a
// compile warning that gets promoted to an error.
// This makes bounds checking simpler because ">= 0" can be avoided. Otherwise adding
// superfluous ">= 0" with unsigned expressions generates compile warnings.
#define ENFORCE_UNSIGNED(c) ((c)==(uint32_t)(c))

/**
 * Runs a packet filtering program over a packet.
 *
 * @param program the program bytecode.
 * @param program_len the length of {@code program} in bytes.
 * @param packet the packet bytes, starting from the 802.3 header and not
 *               including any CRC bytes at the end.
 * @param packet_len the length of {@code packet} in bytes.
 * @param filter_age the number of seconds since the filter was programmed.
 *
 * @return non-zero if packet should be passed to AP, zero if
 *         packet should be dropped.
 *
 * The packet is passed (the function returns non-zero) in all of these cases:
 *   - it is no longer than APF_FRAME_HEADER_SIZE bytes (the program is not run);
 *   - the program is malformed: an out-of-bounds program or packet access,
 *     an unknown opcode, or a division by zero;
 *   - execution reaches offset program_len;
 *   - the instruction budget (program_len + 1 instructions) is exhausted.
 * The packet is dropped only when execution reaches offset program_len + 1.
 */
int accept_packet(const uint8_t* program, uint32_t program_len,
                  const uint8_t* packet, uint32_t packet_len,
                  uint32_t filter_age) {
// Is offset within program bounds?
#define IN_PROGRAM_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < program_len)
// Is offset within packet bounds?
#define IN_PACKET_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < packet_len)
// Accept packet if not within program bounds
#define ASSERT_IN_PROGRAM_BOUNDS(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p))
// Accept packet if not within packet bounds
#define ASSERT_IN_PACKET_BOUNDS(p) ASSERT_RETURN(IN_PACKET_BOUNDS(p))

    // Program counter.
    uint32_t pc = 0;
// Accept packet if not within program or not ahead of program counter
#define ASSERT_FORWARD_IN_PROGRAM(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p) && (p) >= pc)

    // Memory slot values. "{0}" (rather than "{}") is valid C11.
    uint32_t memory[MEMORY_ITEMS] = {0};
    // Fill in pre-filled memory slot values.
    memory[MEMORY_OFFSET_PACKET_SIZE] = packet_len;
    memory[MEMORY_OFFSET_FILTER_AGE] = filter_age;
    // The byte right after the frame header is read below, so packets that do not
    // contain it are passed without running the program.
    ASSERT_IN_PACKET_BOUNDS(APF_FRAME_HEADER_SIZE);
    // Only populate if IP version is IPv4.
    if ((packet[APF_FRAME_HEADER_SIZE] & 0xf0) == 0x40) {
        // The low nibble of the first IPv4 byte is the header length in 32-bit
        // words; store it in bytes.
        memory[MEMORY_OFFSET_IPV4_HEADER_SIZE] = (packet[APF_FRAME_HEADER_SIZE] & 15) * 4;
    }
    // Register values.
    uint32_t registers[2] = {0};
    // Count of instructions remaining to execute. This is done to ensure an
    // upper bound on execution time. It should never be hit and is only for
    // safety. Initialize to the number of bytes in the program which is an
    // upper bound on the number of instructions in the program.
    uint32_t instructions_remaining = program_len;

    do {
        // Jumping exactly to the end of the program accepts; one byte past it drops.
        if (pc == program_len) {
            return PASS_PACKET;
        } else if (pc == (program_len + 1)) {
            return DROP_PACKET;
        }
        ASSERT_IN_PROGRAM_BOUNDS(pc);
        const uint8_t bytecode = program[pc++];
        const uint32_t opcode = EXTRACT_OPCODE(bytecode);
        const uint32_t reg_num = EXTRACT_REGISTER(bytecode);
#define REG (registers[reg_num])
#define OTHER_REG (registers[reg_num ^ 1])
        // All instructions have immediate fields, so load them now.
        const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode);
        uint32_t imm = 0;
        int32_t signed_imm = 0;
        if (len_field != 0) {
            // A non-zero len_field encodes an immediate of 1 << (len_field - 1)
            // bytes, stored most-significant byte first.
            // NOTE(review): the sign extension below assumes imm_len <= 4, i.e.
            // EXTRACT_IMM_LENGTH never yields more than 3 — confirm in apf.h.
            const uint32_t imm_len = 1 << (len_field - 1);
            ASSERT_FORWARD_IN_PROGRAM(pc + imm_len - 1);
            uint32_t i;
            for (i = 0; i < imm_len; i++)
                imm = (imm << 8) | program[pc++];
            // Sign extend imm into signed_imm.
            signed_imm = imm << ((4 - imm_len) * 8);
            signed_imm >>= (4 - imm_len) * 8;
        }
        switch (opcode) {
            case LDB_OPCODE:
            case LDH_OPCODE:
            case LDW_OPCODE:
            case LDBX_OPCODE:
            case LDHX_OPCODE:
            case LDWX_OPCODE: {
                uint32_t offs = imm;
                // NOTE(review): relies on the indexed (X) load opcodes being
                // numbered above the plain ones in apf.h.
                if (opcode >= LDBX_OPCODE) {
                    // Note: this can overflow and actually decrease offs.
                    offs += registers[1];
                }
                ASSERT_IN_PACKET_BOUNDS(offs);
                uint32_t load_size;
                switch (opcode) {
                    case LDB_OPCODE:
                    case LDBX_OPCODE:
                        load_size = 1;
                        break;
                    case LDH_OPCODE:
                    case LDHX_OPCODE:
                        load_size = 2;
                        break;
                    case LDW_OPCODE:
                    case LDWX_OPCODE:
                        load_size = 4;
                        break;
                    default:
                        // Unreachable: the immediately enclosing case labels restrict
                        // opcode to the values above. Bailing out here also guarantees
                        // load_size is initialized on every path.
                        return PASS_PACKET;
                }
                const uint32_t end_offs = offs + (load_size - 1);
                // Catch overflow/wrap-around.
                ASSERT_RETURN(end_offs >= offs);
                ASSERT_IN_PACKET_BOUNDS(end_offs);
                // Packet bytes are read most-significant byte first.
                uint32_t val = 0;
                while (load_size--)
                    val = (val << 8) | packet[offs++];
                REG = val;
                break;
            }
            case JMP_OPCODE:
                // This can jump backwards (via unsigned wrap-around).
                // Infinite looping prevented by instructions_remaining.
                pc += imm;
                break;
            case JEQ_OPCODE:
            case JNE_OPCODE:
            case JGT_OPCODE:
            case JLT_OPCODE:
            case JSET_OPCODE:
            case JNEBS_OPCODE: {
                // Load second immediate field. When the register bit is set, the
                // comparison value comes from register 1 instead of the program.
                uint32_t cmp_imm = 0;
                if (reg_num == 1) {
                    cmp_imm = registers[1];
                } else if (len_field != 0) {
                    uint32_t cmp_imm_len = 1 << (len_field - 1);
                    ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm_len - 1);
                    uint32_t i;
                    for (i = 0; i < cmp_imm_len; i++)
                        cmp_imm = (cmp_imm << 8) | program[pc++];
                }
                switch (opcode) {
                    case JEQ_OPCODE:
                        if (registers[0] == cmp_imm)
                            pc += imm;
                        break;
                    case JNE_OPCODE:
                        if (registers[0] != cmp_imm)
                            pc += imm;
                        break;
                    case JGT_OPCODE:
                        if (registers[0] > cmp_imm)
                            pc += imm;
                        break;
                    case JLT_OPCODE:
                        if (registers[0] < cmp_imm)
                            pc += imm;
                        break;
                    case JSET_OPCODE:
                        if (registers[0] & cmp_imm)
                            pc += imm;
                        break;
                    case JNEBS_OPCODE: {
                        // cmp_imm is size in bytes of data to compare.
                        // pc is offset of program bytes to compare.
                        // imm is jump target offset.
                        // REG is offset of packet bytes to compare.
                        // A zero cmp_imm fails the forward check below and passes.
                        ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1);
                        ASSERT_IN_PACKET_BOUNDS(REG);
                        const uint32_t last_packet_offs = REG + cmp_imm - 1;
                        ASSERT_RETURN(last_packet_offs >= REG);
                        ASSERT_IN_PACKET_BOUNDS(last_packet_offs);
                        if (memcmp(program + pc, packet + REG, cmp_imm))
                            pc += imm;
                        // skip past comparison bytes
                        pc += cmp_imm;
                        break;
                    }
                }
                break;
            }
            case ADD_OPCODE:
                registers[0] += reg_num ? registers[1] : imm;
                break;
            case MUL_OPCODE:
                registers[0] *= reg_num ? registers[1] : imm;
                break;
            case DIV_OPCODE: {
                const uint32_t div_operand = reg_num ? registers[1] : imm;
                // Division by zero is undefined; treat it as a malformed program.
                ASSERT_RETURN(div_operand);
                registers[0] /= div_operand;
                break;
            }
            case AND_OPCODE:
                registers[0] &= reg_num ? registers[1] : imm;
                break;
            case OR_OPCODE:
                registers[0] |= reg_num ? registers[1] : imm;
                break;
            case SH_OPCODE: {
                // Positive shift_val shifts left; zero or negative shifts right by
                // |shift_val|.
                const int32_t shift_val = reg_num ? (int32_t)registers[1] : signed_imm;
                // Take the magnitude in unsigned arithmetic: negating INT32_MIN as an
                // int32_t would overflow (undefined behavior).
                const uint32_t shift_amount =
                        shift_val > 0 ? (uint32_t)shift_val : 0u - (uint32_t)shift_val;
                if (shift_amount >= 32) {
                    // Shifting a 32-bit value by 32 or more bits is undefined in C.
                    // Every bit is shifted out, so define the result as 0.
                    registers[0] = 0;
                } else if (shift_val > 0) {
                    registers[0] <<= shift_amount;
                } else {
                    registers[0] >>= shift_amount;
                }
                break;
            }
            case LI_OPCODE:
                REG = signed_imm;
                break;
            case EXT_OPCODE:
                if (
// If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result,
// instead just enforce that imm is unsigned (so it's always greater or equal to 0).
#if LDM_EXT_OPCODE == 0
                    ENFORCE_UNSIGNED(imm) &&
#else
                    imm >= LDM_EXT_OPCODE &&
#endif
                    imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
                    // Load from memory slot (imm - LDM_EXT_OPCODE).
                    REG = memory[imm - LDM_EXT_OPCODE];
                } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
                    // Store to memory slot (imm - STM_EXT_OPCODE).
                    memory[imm - STM_EXT_OPCODE] = REG;
                } else switch (imm) {
                    case NOT_EXT_OPCODE:
                        REG = ~REG;
                        break;
                    case NEG_EXT_OPCODE:
                        // Unsigned negation is well-defined (two's complement wrap).
                        REG = -REG;
                        break;
                    case SWAP_EXT_OPCODE: {
                        uint32_t tmp = REG;
                        REG = OTHER_REG;
                        OTHER_REG = tmp;
                        break;
                    }
                    case MOV_EXT_OPCODE:
                        REG = OTHER_REG;
                        break;
                    // Unknown extended opcode
                    default:
                        // Bail out
                        return PASS_PACKET;
                }
                break;
            // Unknown opcode
            default:
                // Bail out
                return PASS_PACKET;
        }
    } while (instructions_remaining--);
    return PASS_PACKET;
}