1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17"""Tools to interact with BPF programs."""
18
19import abc
20import collections
21import struct
22
# This comes from syscall(2). Most architectures only support passing 6 args to
# syscalls, but ARM supports passing 7.
# simulate() reserves this many 64-bit argument slots in its input memory.
# NOTE(review): struct seccomp_data in <linux/seccomp.h> declares only
# args[6]; confirm the seventh slot is intended.
MAX_SYSCALL_ARGUMENTS = 7

# The following fields were copied from <linux/bpf_common.h>:

# Instruction classes
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# LD/LDX fields.
# Size
BPF_W = 0x00
BPF_H = 0x08
BPF_B = 0x10
# Mode
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# JMP fields.
# (Combined as BPF_JMP | <op> | <source> to form a jump opcode.)
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Source
BPF_K = 0x00
BPF_X = 0x08

BPF_MAXINSNS = 4096

# The following fields were copied from <linux/seccomp.h>:

# Seccomp actions, used as the operand of a BPF_RET instruction.
SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

# Masks splitting a seccomp return value into its action bits and its data
# bits (e.g. the errno carried by SECCOMP_RET_ERRNO).
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff
78
79
def arg_offset(arg_index, hi=False):
    """Return the BPF_LD|BPF_W|BPF_ABS byte offset of a syscall argument word.

    Args:
      arg_index: Index of the syscall argument.
      hi: If true, return the offset of the argument's second 32-bit word
        (the one LoweringVisitor compares against the high 32 bits).
        NOTE(review): this assumes a little-endian argument layout.
    """
    # The 64-bit arguments start after the syscall number (4 bytes), the
    # arch (4 bytes) and one 64-bit word -- the same layout simulate() packs.
    args_start = 16
    slot_size = 8
    half_slot = slot_size // 2
    return args_start + slot_size * arg_index + half_slot * hi
85
86
def simulate(instructions, arch, syscall_number, *args):
    """Simulate a seccomp BPF program on one syscall invocation.

    Args:
      instructions: Sequence of SockFilter instructions to execute.
      arch: The architecture number presented to the program.
      syscall_number: The syscall number presented to the program.
      *args: Syscall arguments. Missing ones are zero-filled; any beyond
        MAX_SYSCALL_ARGUMENTS are dropped.

    Returns:
      (cost, action) where |cost| is the number of instructions executed and
      |action| is the name of the seccomp action taken, or
      (cost, 'ERRNO', errno) for SECCOMP_RET_ERRNO.

    Raises:
      Exception: on an unsupported instruction, an unknown return action, or
        if execution runs past the end of the program.
    """
    args = ((args + (0, ) *
             (MAX_SYSCALL_ARGUMENTS - len(args)))[:MAX_SYSCALL_ARGUMENTS])
    # Input layout: syscall number at offset 0, arch at 4, a zeroed 64-bit
    # word at 8, then the 64-bit arguments from offset 16 (see arg_offset()).
    input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS,
                               syscall_number, arch, 0, *args)

    register = 0
    program_counter = 0
    cost = 0
    while program_counter < len(instructions):
        ins = instructions[program_counter]
        program_counter += 1
        cost += 1
        if ins.code == BPF_LD | BPF_W | BPF_ABS:
            register = struct.unpack('I', input_memory[ins.k:ins.k + 4])[0]
        elif ins.code == BPF_JMP | BPF_JA | BPF_K:
            program_counter += ins.k
        elif ins.code == BPF_JMP | BPF_JEQ | BPF_K:
            if register == ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGT | BPF_K:
            if register > ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGE | BPF_K:
            if register >= ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JSET | BPF_K:
            if register & ins.k != 0:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_RET:
            # As in the kernel, only the SECCOMP_RET_ACTION_FULL bits select
            # the action; the SECCOMP_RET_DATA bits are the action's payload
            # (e.g. SECCOMP_RET_TRACE | msg). Matching the full value would
            # reject such valid return values.
            action = ins.k & SECCOMP_RET_ACTION_FULL
            if action == SECCOMP_RET_KILL_PROCESS:
                return (cost, 'KILL_PROCESS')
            if action == SECCOMP_RET_KILL_THREAD:
                return (cost, 'KILL_THREAD')
            if action == SECCOMP_RET_TRAP:
                return (cost, 'TRAP')
            if action == SECCOMP_RET_ERRNO:
                return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA)
            if action == SECCOMP_RET_TRACE:
                return (cost, 'TRACE')
            if action == SECCOMP_RET_USER_NOTIF:
                return (cost, 'USER_NOTIF')
            if action == SECCOMP_RET_LOG:
                return (cost, 'LOG')
            if action == SECCOMP_RET_ALLOW:
                return (cost, 'ALLOW')
            raise Exception('unknown return %#x' % ins.k)
        else:
            raise Exception('unknown instruction %r' % (ins, ))
    raise Exception('out-of-bounds')
146
147
class SockFilter(
        collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])):
    """One classic-BPF instruction, mirroring struct sock_filter.

    Fields: |code| is the 16-bit opcode, |jt| / |jf| are the 8-bit jump
    offsets for the true / false outcome of a conditional jump, and |k| is
    the 32-bit operand.
    """

    __slots__ = ()

    def encode(self):
        """Pack the instruction into its native binary layout."""
        code, jt, jf, k = self
        return struct.pack('HBBI', code, jt, jf, k)
157
158
class AbstractBlock(abc.ABC):
    """Base class of every node in a block DAG (the visitor pattern's element).

    Subclasses implement accept() to hand themselves, and usually their
    successors, to a visitor.
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        """Let |visitor| visit this block."""
168
169
class BasicBlock(AbstractBlock):
    """A concrete implementation of AbstractBlock that has been compiled.

    Holds a list of instructions, each exposing encode() (SockFilter in this
    module). Two BasicBlocks compare equal when their instruction lists are
    equal. Instances are unhashable: the instruction list is stored by
    reference rather than copied, and the visitors track blocks by id().
    """

    # Defining __eq__ already disables the inherited __hash__; spell it out.
    __hash__ = None

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        """Visit this block unless |visitor| has already seen it."""
        if visitor.visited(self):
            return
        visitor.visit(self)

    @property
    def instructions(self):
        """The instructions held by this block."""
        return self._instructions

    @property
    def opcodes(self):
        """The instructions encoded back-to-back as bytes."""
        return b''.join(i.encode() for i in self._instructions)

    def __eq__(self, o):
        if not isinstance(o, BasicBlock):
            # Let the other operand decide; Python falls back to identity
            # (i.e. False) if it doesn't know how to compare either.
            return NotImplemented
        return self._instructions == o._instructions
194
195
class KillProcess(BasicBlock):
    """A single-instruction BasicBlock whose program returns KILL_PROCESS."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_KILL_PROCESS)
        super().__init__([ret])
202
203
class KillThread(BasicBlock):
    """A single-instruction BasicBlock whose program returns KILL_THREAD."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_KILL_THREAD)
        super().__init__([ret])
210
211
class Trap(BasicBlock):
    """A single-instruction BasicBlock whose program returns TRAP."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_TRAP)
        super().__init__([ret])
217
218
class Trace(BasicBlock):
    """A single-instruction BasicBlock whose program returns TRACE."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_TRACE)
        super().__init__([ret])
224
225
class UserNotify(BasicBlock):
    """A single-instruction BasicBlock whose program returns USER_NOTIF."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_USER_NOTIF)
        super().__init__([ret])
231
232
class Log(BasicBlock):
    """A single-instruction BasicBlock whose program returns LOG."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_LOG)
        super().__init__([ret])
238
239
class ReturnErrno(BasicBlock):
    """A single-instruction BasicBlock that fails the syscall with an errno.

    Only the low bits of |errno| selected by SECCOMP_RET_DATA are encoded in
    the return value; the unmasked value is kept in the ``errno`` attribute.
    """

    def __init__(self, errno):
        action = SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA)
        super().__init__([SockFilter(code=BPF_RET, jt=0, jf=0, k=action)])
        self.errno = errno
249
250
class Allow(BasicBlock):
    """A single-instruction BasicBlock whose program returns ALLOW."""

    def __init__(self):
        ret = SockFilter(code=BPF_RET, jt=0, jf=0, k=SECCOMP_RET_ALLOW)
        super().__init__([ret])
256
257
class ValidateArch(AbstractBlock):
    """An AbstractBlock that checks the architecture before |next_block| runs.

    The block stores only its successor; the expected architecture number
    is supplied when the DAG is flattened (FlatteningVisitor's |arch|).
    """

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        """Visit |next_block| first and then this block, unless already seen."""
        if not visitor.visited(self):
            self.next_block.accept(visitor)
            visitor.visit(self)
270
271
class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG.

    Args:
      syscall_number: Constant the loaded syscall number is compared with.
      jt: Block to continue with when the comparison is true.
      jf: Block to continue with when the comparison is false.
      op: BPF jump opcode used for the comparison (BPF_JEQ by default).
    """

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    # SyscallEntries never order before or after one another. These exist
    # (once -- they were previously defined twice) because we want to
    # compare tuples that contain SyscallEntries.
    def __lt__(self, o):
        return False

    def __gt__(self, o):
        return False

    def accept(self, visitor):
        """Visit both successors first, then this block, unless already seen."""
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
304
305
class WideAtom(AbstractBlock):
    """An AbstractBlock comparing one 32-bit word of a syscall argument.

    Args:
      arg_offset: Byte offset of the word to load (see arg_offset()).
      op: BPF jump opcode for the comparison (as produced by
        LoweringVisitor: BPF_JEQ, BPF_JGT, BPF_JGE or BPF_JSET).
      value: Constant to compare the loaded word against.
      jt: Block to continue with when the comparison is true.
      jf: Block to continue with when the comparison is false.
    """

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.arg_offset = arg_offset
        self.op = op
        self.value = value
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        """Visit both successors first, then this block, unless already seen."""
        if visitor.visited(self):
            return
        for successor in (self.jt, self.jf):
            successor.accept(visitor)
        visitor.visit(self)
323
324
class Atom(AbstractBlock):
    """An AbstractBlock for one comparison of a syscall argument (an atom).

    The operator string is translated into a BPF jump opcode at construction
    time. Operators without a direct opcode ('!=', '<=', '<', 'in') reuse the
    opcode of their complement and swap |jt| and |jf|. 'in' also stores the
    64-bit negated mask so the test can be done with BPF_JSET.

    Raises:
      Exception: if |op| is not a supported operator.
    """

    # (operator, BPF jump opcode, whether jt and jf are swapped)
    _OPERATORS = (
        ('==', BPF_JEQ, False),
        ('!=', BPF_JEQ, True),
        ('>', BPF_JGT, False),
        ('<=', BPF_JGT, True),
        ('>=', BPF_JGE, False),
        ('<', BPF_JGE, True),
        ('&', BPF_JSET, False),
        ('in', BPF_JSET, True),
    )

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        for operator, opcode, swap in self._OPERATORS:
            if op == operator:
                break
        else:
            raise Exception('Unknown operator %s' % op)

        if operator == 'in':
            # BPF_JSET against the negated mask is true exactly when the
            # argument has a flag outside the original mask -- the failure
            # case, which is why 'in' is marked as swapping jt/jf above.
            value = (~value) & ((1 << 64) - 1)
        if swap:
            jt, jf = jf, jt

        self.arg_index = arg_index
        self.op = opcode
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        """Visit both successors first, then this block, unless already seen."""
        if visitor.visited(self):
            return
        for successor in (self.jt, self.jf):
            successor.accept(visitor)
        visitor.visit(self)
370
371
class AbstractVisitor(abc.ABC):
    """Base class for visitors over a DAG of blocks.

    Subclasses implement one visitXxx() method per block type; visit()
    dispatches each block to the matching one. visited() records blocks by
    id() so a block shared by several parents is handled once.
    """

    def __init__(self):
        self._visited = set()

    def visited(self, block):
        """Return True if |block| was seen before; otherwise mark it as seen."""
        key = id(block)
        already_seen = key in self._visited
        self._visited.add(key)
        return already_seen

    def process(self, block):
        """Run this visitor over the DAG rooted at |block| and return |block|."""
        block.accept(self)
        return block

    def visit(self, block):
        """Dispatch |block| to the visitXxx() method for its type."""
        # Order matters: the action blocks subclass BasicBlock, so they must
        # be matched before the generic BasicBlock entry.
        handlers = (
            (KillProcess, self.visitKillProcess),
            (KillThread, self.visitKillThread),
            (Trap, self.visitTrap),
            (ReturnErrno, self.visitReturnErrno),
            (Trace, self.visitTrace),
            (UserNotify, self.visitUserNotify),
            (Log, self.visitLog),
            (Allow, self.visitAllow),
            (BasicBlock, self.visitBasicBlock),
            (ValidateArch, self.visitValidateArch),
            (SyscallEntry, self.visitSyscallEntry),
            (WideAtom, self.visitWideAtom),
            (Atom, self.visitAtom),
        )
        for block_type, handler in handlers:
            if isinstance(block, block_type):
                handler(block)
                return
        raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        """Handle a KillProcess block."""

    @abc.abstractmethod
    def visitKillThread(self, block):
        """Handle a KillThread block."""

    @abc.abstractmethod
    def visitTrap(self, block):
        """Handle a Trap block."""

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        """Handle a ReturnErrno block."""

    @abc.abstractmethod
    def visitTrace(self, block):
        """Handle a Trace block."""

    @abc.abstractmethod
    def visitUserNotify(self, block):
        """Handle a UserNotify block."""

    @abc.abstractmethod
    def visitLog(self, block):
        """Handle a Log block."""

    @abc.abstractmethod
    def visitAllow(self, block):
        """Handle an Allow block."""

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        """Handle any other BasicBlock."""

    @abc.abstractmethod
    def visitValidateArch(self, block):
        """Handle a ValidateArch block."""

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        """Handle a SyscallEntry block."""

    @abc.abstractmethod
    def visitWideAtom(self, block):
        """Handle a WideAtom block."""

    @abc.abstractmethod
    def visitAtom(self, block):
        """Handle an Atom block."""
469
470
class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks.

    process() returns the copy of the root block. Each original block is
    copied once and later looked up by id(), so a block shared by several
    parents in the input stays shared in the copy.
    """

    # Atom's constructor stores the BPF opcode plus the final value and
    # jt/jf (already negated/swapped where needed). Re-feeding those through
    # the matching operator that neither swaps nor rewrites the value
    # reproduces the Atom exactly.
    _ATOM_OPERATORS = {
        BPF_JEQ: '==',
        BPF_JGT: '>',
        BPF_JGE: '>=',
        BPF_JSET: '&',
    }

    def __init__(self):
        super().__init__()
        self._mapping = {}

    def process(self, block):
        # Reset both per-run tables. If |_visited| survived a previous run,
        # blocks seen then (or new blocks that reuse a freed block's id())
        # would be skipped and have no entry in |_mapping|.
        self._visited = set()
        self._mapping = {}
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitUserNotify(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = UserNotify()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        # ValidateArch only carries its successor; the architecture is
        # supplied later, when the DAG is flattened.
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)])

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op)

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset, block.op, block.value, self._mapping[id(
                block.jt)], self._mapping[id(block.jf)])

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        # block.op is already a BPF opcode, which Atom() does not accept as
        # an operator; translate it back to its non-swapping operator string.
        self._mapping[id(block)] = Atom(block.arg_index,
                                        self._ATOM_OPERATORS[block.op],
                                        block.value,
                                        self._mapping[id(block.jt)],
                                        self._mapping[id(block.jf)])
543
544
class LoweringVisitor(CopyingVisitor):
    """A CopyingVisitor that also lowers Atoms into 32-bit WideAtoms.

    A comparison on a 64-bit argument is split into comparisons of its high
    and low 32-bit words. When |arch.bits| is 32 only the low word is compared.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        assert id(block) not in self._mapping

        on_true = self._mapping[id(block.jt)]
        on_false = self._mapping[id(block.jf)]
        lo_value = block.value & 0xFFFFFFFF
        hi_value = (block.value >> 32) & 0xFFFFFFFF

        lo_block = WideAtom(
            arg_offset(block.arg_index, False), block.op, lo_value, on_true,
            on_false)

        if self._bits == 32:
            self._mapping[id(block)] = lo_block
            return

        hi_offset = arg_offset(block.arg_index, True)

        if block.op in (BPF_JGE, BPF_JGT):
            # (hi_1:lo_1) <op> (hi_2:lo_2) holds iff
            #   hi_1 > hi_2 || (hi_1 == hi_2 && lo_1 <op> lo_2).
            if hi_value == 0:
                # When "hi_1 > 0" fails, hi_1 == 0 == hi_2 already holds, so
                # the equality check can be skipped.
                not_greater = lo_block
            else:
                not_greater = WideAtom(hi_offset, BPF_JEQ, hi_value, lo_block,
                                       on_false)
            lowered = WideAtom(hi_offset, BPF_JGT, hi_value, on_true,
                               not_greater)
        elif block.op == BPF_JSET:
            # (hi_1:lo_1) & (hi_2:lo_2) is nonzero iff
            #   (hi_1 & hi_2) || (lo_1 & lo_2).
            if hi_value == 0:
                # hi_1 & 0 can never be set; only the low words matter.
                lowered = lo_block
            else:
                lowered = WideAtom(hi_offset, BPF_JSET, hi_value, on_true,
                                   lo_block)
        else:
            assert block.op == BPF_JEQ, block.op
            # (hi_1:lo_1) == (hi_2:lo_2) iff hi_1 == hi_2 && lo_1 == lo_2.
            lowered = WideAtom(hi_offset, BPF_JEQ, hi_value, lo_block,
                               on_false)

        self._mapping[id(block)] = lowered
606
607
class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects.

    Blocks are visited successors-first (see the accept() methods). Each
    block's instructions are prepended to the program built so far, so every
    jump target is already emitted and all jumps go forward.
    """

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # id(block) -> -(program length right after |block| was prepended).
        self._offsets = {}

    @property
    def result(self):
        """The flattened program as a BasicBlock."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return how many instructions separate the current start of the
        program from the already-emitted |block|."""
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Emit a conditional jump to targets |jt_distance| / |jf_distance|
        instructions past the end of the returned sequence.

        jt/jf offsets only hold 8 bits, so targets that are too far away are
        reached through unconditional BPF_JA instructions.
        """
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance,
                           value),
            ]
        if jt_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
            ]
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
        ]

    def _emit_validate_arch(self, block):
        """Emit the architecture check guarding |block.next_block|.

        Layout: load the arch (offset 4); if it matches, skip over the kill
        action; the kill action; load the syscall number (offset 0); and, only
        if |next_block| does not immediately follow, a BPF_JA to it. The skip
        is derived from the kill action's length instead of assuming a single
        instruction.
        """
        kill_instructions = list(self._kill_action.instructions)
        distance = self._distance(block.next_block)
        instructions = [
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
            SockFilter(BPF_JMP | BPF_JEQ | BPF_K, len(kill_instructions), 0,
                       self._arch.arch_nr),
        ]
        instructions += kill_instructions
        instructions.append(SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0))
        if distance:
            instructions.append(SockFilter(BPF_JMP | BPF_JA, 0, 0, distance))
        return instructions

    def visited(self, block):
        """Return True if |block| was seen before; otherwise mark it as seen."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        """Prepend the instructions for |block| to the program."""
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            instructions = self._emit_validate_arch(block)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(block.op, block.syscall_number,
                                          self._distance(block.jt),
                                          self._distance(block.jf))
        elif isinstance(block, WideAtom):
            instructions = (
                self._emit_load_arg(block.arg_offset) + self._emit_jmp(
                    block.op, block.value, self._distance(block.jt),
                    self._distance(block.jf)))
        else:
            raise Exception('Unknown block type: %r' % block)

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)
687
688
class ArgFilterForwardingVisitor:
    """A visitor that hands every argument-filter block to another visitor."""

    def __init__(self, visitor):
        self._visited = set()
        self.visitor = visitor

    def visited(self, block):
        """Return True if |block| was seen before; otherwise mark it as seen."""
        key = id(block)
        already_seen = key in self._visited
        self._visited.add(key)
        return already_seen

    def visit(self, block):
        # Arg filters are plain BasicBlocks. The terminal actions (ALLOW,
        # KILL_PROCESS, TRAP, ...) subclass BasicBlock too, but they must not
        # be forwarded yet.
        action_types = (KillProcess, KillThread, Trap, ReturnErrno, Trace,
                        UserNotify, Log, Allow)
        if isinstance(block, BasicBlock) and not isinstance(
                block, action_types):
            block.accept(self.visitor)
714