1 /*
2  * Copyright 2016, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdint.h>
18 #include <stdio.h>
19 #include <stdarg.h>
20 
/* Boolean type used throughout this file, declared as a plain enum rather than
 * via <stdbool.h>. It precedes the APF header includes below, presumably because
 * those headers refer to `bool` — TODO confirm. */
typedef enum { false = 0, true = 1 } bool;
22 
23 #include "v7/apf_defs.h"
24 #include "v7/apf.h"
25 #include "disassembler.h"
26 
// Always true for a uint32_t argument; it exists for its type check. If "c" has
// a signed type, comparing it with its own uint32_t conversion produces a
// sign-compare warning, which the build promotes to an error. That lets bounds
// checks omit a ">= 0" test, which would itself warn on unsigned operands.
#define ENFORCE_UNSIGNED(c) ((uint32_t)(c) == (c))
31 
32 char print_buf[1024];
33 char* buf_ptr;
34 int buf_remain;
35 bool v6_mode = false;
36 
37 __attribute__ ((format (printf, 1, 2) ))
bprintf(const char * format,...)38 static void bprintf(const char* format, ...) {
39     va_list args;
40     va_start(args, format);
41     int ret = vsnprintf(buf_ptr, buf_remain, format, args);
42     va_end(args);
43     if (ret < 0) return;
44     if (ret >= buf_remain) ret = buf_remain;
45     buf_ptr += ret;
46     buf_remain -= ret;
47 }
48 
// Prints an instruction mnemonic left-justified in a 12-character column, so the
// operands that follow it line up across instructions.
static void print_opcode(const char* opcode) {
    const int column_width = 12;
    bprintf("%-*s", column_width, opcode);
}
52 
53 // Mapping from opcode number to opcode name.
54 static const char* opcode_names [] = {
55     [LDB_OPCODE] = "ldb",
56     [LDH_OPCODE] = "ldh",
57     [LDW_OPCODE] = "ldw",
58     [LDBX_OPCODE] = "ldbx",
59     [LDHX_OPCODE] = "ldhx",
60     [LDWX_OPCODE] = "ldwx",
61     [ADD_OPCODE] = "add",
62     [MUL_OPCODE] = "mul",
63     [DIV_OPCODE] = "div",
64     [AND_OPCODE] = "and",
65     [OR_OPCODE] = "or",
66     [SH_OPCODE] = "sh",
67     [LI_OPCODE] = "li",
68     [JMP_OPCODE] = "jmp",
69     [JEQ_OPCODE] = "jeq",
70     [JNE_OPCODE] = "jne",
71     [JGT_OPCODE] = "jgt",
72     [JLT_OPCODE] = "jlt",
73     [JSET_OPCODE] = "jset",
74     [JBSMATCH_OPCODE] = NULL,
75     [LDDW_OPCODE] = "lddw",
76     [STDW_OPCODE] = "stdw",
77     [WRITE_OPCODE] = "write",
78     [JNSET_OPCODE] = "jnset",
79 };
80 
// Prints a jump destination. The two offsets just past the program are the
// pseudo-targets PASS (program_len) and DROP (program_len + 1); every other
// target is printed as an unsigned byte offset.
static void print_jump_target(uint32_t target, uint32_t program_len) {
    const char* label = NULL;
    if (target == program_len) {
        label = "PASS";
    } else if (target == program_len + 1) {
        label = "DROP";
    }

    if (label != NULL) {
        bprintf("%s", label);
    } else {
        bprintf("%u", target);
    }
}
90 
apf_disassemble(const uint8_t * program,uint32_t program_len,uint32_t * const ptr2pc)91 const char* apf_disassemble(const uint8_t* program, uint32_t program_len, uint32_t* const ptr2pc) {
92     buf_ptr = print_buf;
93     buf_remain = sizeof(print_buf);
94     if (*ptr2pc > program_len + 1) {
95         bprintf("pc is overflow: pc %d, program_len: %d", *ptr2pc, program_len);
96         return print_buf;
97     }
98 
99     bprintf("%8u: ", *ptr2pc);
100 
101     if (*ptr2pc == program_len) {
102         bprintf("PASS");
103         ++(*ptr2pc);
104         return print_buf;
105     }
106 
107     if (*ptr2pc == program_len + 1) {
108         bprintf("DROP");
109         ++(*ptr2pc);
110         return print_buf;
111     }
112 
113     const uint8_t bytecode = program[(*ptr2pc)++];
114     const uint32_t opcode = EXTRACT_OPCODE(bytecode);
115 
116 #define PRINT_OPCODE() print_opcode(opcode_names[opcode])
117 #define DECODE_IMM(length)  ({                                        \
118     uint32_t value = 0;                                               \
119     for (uint32_t i = 0; i < (length) && *ptr2pc < program_len; i++)  \
120         value = (value << 8) | program[(*ptr2pc)++];                  \
121     value;})
122 
123     const uint32_t reg_num = EXTRACT_REGISTER(bytecode);
124     // All instructions have immediate fields, so load them now.
125     const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode);
126     uint32_t imm = 0;
127     int32_t signed_imm = 0;
128     if (len_field != 0) {
129         const uint32_t imm_len = 1 << (len_field - 1);
130         imm = DECODE_IMM(imm_len);
131         // Sign extend imm into signed_imm.
132         signed_imm = imm << ((4 - imm_len) * 8);
133         signed_imm >>= (4 - imm_len) * 8;
134     }
135     switch (opcode) {
136         case PASSDROP_OPCODE:
137             if (reg_num == 0) {
138                 print_opcode("pass");
139             } else {
140                 print_opcode("drop");
141             }
142             if (imm > 0) {
143                 bprintf("counter=%d", imm);
144             }
145             break;
146         case LDB_OPCODE:
147         case LDH_OPCODE:
148         case LDW_OPCODE:
149             PRINT_OPCODE();
150             bprintf("r%d, [%u]", reg_num, imm);
151             break;
152         case LDBX_OPCODE:
153         case LDHX_OPCODE:
154         case LDWX_OPCODE:
155             PRINT_OPCODE();
156             if (imm) {
157                 bprintf("r%d, [r1+%u]", reg_num, imm);
158             } else {
159                 bprintf("r%d, [r1]", reg_num);
160             }
161             break;
162         case JMP_OPCODE:
163             if (reg_num == 0) {
164                 PRINT_OPCODE();
165                 print_jump_target(*ptr2pc + imm, program_len);
166             } else {
167                 v6_mode = true;
168                 print_opcode("data");
169                 bprintf("%d, ", imm);
170                 uint32_t len = imm;
171                 while (len--) bprintf("%02x", program[(*ptr2pc)++]);
172             }
173             break;
174         case JEQ_OPCODE:
175         case JNE_OPCODE:
176         case JGT_OPCODE:
177         case JLT_OPCODE:
178         case JSET_OPCODE:
179         case JNSET_OPCODE: {
180             PRINT_OPCODE();
181             bprintf("r0, ");
182             // Load second immediate field.
183             if (reg_num == 1) {
184                 bprintf("r1, ");
185             } else if (len_field == 0) {
186                 bprintf("0, ");
187             } else {
188                 uint32_t cmp_imm = DECODE_IMM(1 << (len_field - 1));
189                 bprintf("0x%x, ", cmp_imm);
190             }
191             print_jump_target(*ptr2pc + imm, program_len);
192             break;
193         }
194         case JBSMATCH_OPCODE: {
195             if (reg_num == 0) {
196                 print_opcode("jbsne");
197             } else {
198                 print_opcode("jbseq");
199             }
200             bprintf("r0, ");
201             const uint32_t cmp_imm = DECODE_IMM(1 << (len_field - 1));
202             const uint32_t cnt = (cmp_imm >> 11) + 1; // 1+, up to 32 fits in u16
203             const uint32_t len = cmp_imm & 2047; // 0..2047
204             bprintf("0x%x, ", len);
205             print_jump_target(*ptr2pc + imm + cnt * len, program_len);
206             bprintf(", ");
207             if (cnt > 1) {
208                 bprintf("{ ");
209             }
210             for (uint32_t i = 0; i < cnt; ++i) {
211                 for (uint32_t j = 0; j < len; ++j) {
212                     uint8_t byte = program[(*ptr2pc)++];
213                     bprintf("%02x", byte);
214                 }
215                 if (i != cnt - 1) {
216                     bprintf(", ");
217                 }
218             }
219             if (cnt > 1) {
220                 bprintf(" }");
221             }
222             break;
223         }
224         case SH_OPCODE:
225             PRINT_OPCODE();
226             if (reg_num) {
227                 bprintf("r0, r1");
228             } else {
229                 bprintf("r0, %d", signed_imm);
230             }
231             break;
232         case ADD_OPCODE:
233         case MUL_OPCODE:
234         case DIV_OPCODE:
235         case AND_OPCODE:
236         case OR_OPCODE:
237             PRINT_OPCODE();
238             if (reg_num) {
239                 bprintf("r0, r1");
240             } else if (!imm && opcode == DIV_OPCODE) {
241                 bprintf("pass (div 0)");
242             } else {
243                 bprintf("r0, %u", imm);
244             }
245             break;
246         case LI_OPCODE:
247             PRINT_OPCODE();
248             bprintf("r%d, %d", reg_num, signed_imm);
249             break;
250         case EXT_OPCODE:
251             if (
252 // If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result,
253 // instead just enforce that imm is unsigned (so it's always greater or equal to 0).
254 #if LDM_EXT_OPCODE == 0
255                 ENFORCE_UNSIGNED(imm) &&
256 #else
257                 imm >= LDM_EXT_OPCODE &&
258 #endif
259                 imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
260                 print_opcode("ldm");
261                 bprintf("r%d, m[%u]", reg_num, imm - LDM_EXT_OPCODE);
262             } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
263                 print_opcode("stm");
264                 bprintf("r%d, m[%u]", reg_num, imm - STM_EXT_OPCODE);
265             } else switch (imm) {
266                 case NOT_EXT_OPCODE:
267                     print_opcode("not");
268                     bprintf("r%d", reg_num);
269                     break;
270                 case NEG_EXT_OPCODE:
271                     print_opcode("neg");
272                     bprintf("r%d", reg_num);
273                     break;
274                 case SWAP_EXT_OPCODE:
275                     print_opcode("swap");
276                     break;
277                 case MOV_EXT_OPCODE:
278                     print_opcode("mov");
279                     bprintf("r%d, r%d", reg_num, reg_num ^ 1);
280                     break;
281                 case ALLOCATE_EXT_OPCODE:
282                     print_opcode("allocate");
283                     if (reg_num == 0) {
284                         bprintf("r%d", reg_num);
285                     } else {
286                         uint32_t alloc_len = DECODE_IMM(2);
287                         bprintf("%d", alloc_len);
288                     }
289                     break;
290                 case TRANSMIT_EXT_OPCODE:
291                     print_opcode(reg_num ? "transmitudp" : "transmit");
292                     u8 ip_ofs = DECODE_IMM(1);
293                     u8 csum_ofs = DECODE_IMM(1);
294                     if (csum_ofs < 255) {
295                         u8 csum_start = DECODE_IMM(1);
296                         u16 partial_csum = DECODE_IMM(2);
297                         bprintf("ip_ofs=%d, csum_ofs=%d, csum_start=%d, partial_csum=0x%04x",
298                                 ip_ofs, csum_ofs, csum_start, partial_csum);
299                     } else {
300                         bprintf("ip_ofs=%d", ip_ofs);
301                     }
302                     break;
303                 case EWRITE1_EXT_OPCODE: print_opcode("ewrite1"); bprintf("r%d", reg_num); break;
304                 case EWRITE2_EXT_OPCODE: print_opcode("ewrite2"); bprintf("r%d", reg_num); break;
305                 case EWRITE4_EXT_OPCODE: print_opcode("ewrite4"); bprintf("r%d", reg_num); break;
306                 case EPKTDATACOPYIMM_EXT_OPCODE:
307                 case EPKTDATACOPYR1_EXT_OPCODE: {
308                     if (reg_num == 0) {
309                         print_opcode("epktcopy");
310                     } else {
311                         print_opcode("edatacopy");
312                     }
313                     if (imm == EPKTDATACOPYIMM_EXT_OPCODE) {
314                         uint32_t len = DECODE_IMM(1);
315                         bprintf(" src=r0, len=%d", len);
316                     } else {
317                         bprintf(" src=r0, len=r1");
318                     }
319 
320                     break;
321                 }
322                 case JDNSQMATCH_EXT_OPCODE:       // 43
323                 case JDNSAMATCH_EXT_OPCODE:       // 44
324                 case JDNSQMATCHSAFE_EXT_OPCODE:   // 45
325                 case JDNSAMATCHSAFE_EXT_OPCODE: { // 46
326                     uint32_t offs = DECODE_IMM(1 << (len_field - 1));
327                     int qtype = -1;
328                     switch(imm) {
329                         case JDNSQMATCH_EXT_OPCODE:
330                             print_opcode(reg_num ? "jdnsqeq" : "jdnsqne");
331                             qtype = DECODE_IMM(1);
332                             break;
333                         case JDNSQMATCHSAFE_EXT_OPCODE:
334                             print_opcode(reg_num ? "jdnsqeqsafe" : "jdnsqnesafe");
335                             qtype = DECODE_IMM(1);
336                             break;
337                         case JDNSAMATCH_EXT_OPCODE:
338                             print_opcode(reg_num ? "jdnsaeq" : "jdnsane"); break;
339                         case JDNSAMATCHSAFE_EXT_OPCODE:
340                             print_opcode(reg_num ? "jdnsaeqsafe" : "jdnsanesafe"); break;
341                         default:
342                             bprintf("unknown_ext %u", imm); break;
343                     }
344                     bprintf("r0, ");
345                     uint32_t end = *ptr2pc;
346                     while (end + 1 < program_len && !(program[end] == 0 && program[end + 1] == 0)) {
347                         end++;
348                     }
349                     end += 2;
350                     print_jump_target(end + offs, program_len);
351                     bprintf(", ");
352                     if (imm == JDNSQMATCH_EXT_OPCODE || imm == JDNSQMATCHSAFE_EXT_OPCODE) {
353                         bprintf("%d, ", qtype);
354                     }
355                     while (*ptr2pc < end) {
356                         uint8_t byte = program[(*ptr2pc)++];
357                         // values < 0x40 could be lengths, but - and 0..9 are in practice usually
358                         // too long to be lengths so print them as characters. All other chars < 0x40
359                         // are not valid in dns character.
360                         if (byte == '-' || (byte >= '0' && byte <= '9') || byte >= 0x40) {
361                             bprintf("%c", byte);
362                         } else {
363                             bprintf("(%d)", byte);
364                         }
365                     }
366                     break;
367                 }
368                 case JONEOF_EXT_OPCODE: {
369                     const uint32_t imm_len = 1 << (len_field - 1);
370                     uint32_t jump_offs = DECODE_IMM(imm_len);
371                     uint8_t imm3 = DECODE_IMM(1);
372                     bool jmp = imm3 & 1;
373                     uint8_t len = ((imm3 >> 1) & 3) + 1;
374                     uint8_t cnt = (imm3 >> 3) + 2;
375                     if (jmp) {
376                         print_opcode("jnoneof");
377                     } else {
378                         print_opcode("joneof");
379                     }
380                     bprintf("r%d, ", reg_num);
381                     print_jump_target(*ptr2pc + jump_offs + cnt * len, program_len);
382                     bprintf(", { ");
383                     while (cnt--) {
384                         uint32_t v = DECODE_IMM(len);
385                         if (cnt) {
386                             bprintf("%d, ", v);
387                         } else {
388                             bprintf("%d ", v);
389                         }
390                     }
391                     bprintf("}");
392                     break;
393                 }
394                 case EXCEPTIONBUFFER_EXT_OPCODE: {
395                     uint32_t buf_size = DECODE_IMM(2);
396                     print_opcode("debugbuf");
397                     bprintf("size=%d", buf_size);
398                     break;
399                 }
400                 default:
401                     bprintf("unknown_ext %u", imm);
402                     break;
403             }
404             break;
405         case LDDW_OPCODE:
406         case STDW_OPCODE:
407             PRINT_OPCODE();
408             if (v6_mode) {
409                 if (opcode == LDDW_OPCODE) {
410                     bprintf("r%u, counter=%d", reg_num, imm);
411                 } else {
412                     bprintf("counter=%d, r%u", imm, reg_num);
413                 }
414             } else {
415                 if (signed_imm > 0) {
416                     bprintf("r%u, [r%u+%d]", reg_num, reg_num ^ 1, signed_imm);
417                 } else if (signed_imm < 0) {
418                     bprintf("r%u, [r%u-%d]", reg_num, reg_num ^ 1, -signed_imm);
419                 } else {
420                     bprintf("r%u, [r%u]", reg_num, reg_num ^ 1);
421                 }
422             }
423             break;
424         case WRITE_OPCODE: {
425             PRINT_OPCODE();
426             uint32_t write_len = 1 << (len_field - 1);
427             if (write_len > 0) {
428                 bprintf("0x");
429             }
430             for (uint32_t i = 0; i < write_len; ++i) {
431                 uint8_t byte =
432                     (uint8_t) ((imm >> (write_len - 1 - i) * 8) & 0xff);
433                 bprintf("%02x", byte);
434 
435             }
436             break;
437         }
438         case PKTDATACOPY_OPCODE: {
439             if (reg_num == 0) {
440                 print_opcode("pktcopy");
441             } else {
442                 print_opcode("datacopy");
443             }
444             uint32_t src_offs = imm;
445             uint32_t copy_len = DECODE_IMM(1);
446             bprintf("src=%d, len=%d", src_offs, copy_len);
447             break;
448         }
449         // Unknown opcode
450         default:
451             bprintf("unknown %u", opcode);
452             break;
453     }
454     return print_buf;
455 }
456