1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17"""A BPF compiler for the Minijail policy file."""
18
19from __future__ import print_function
20
21import enum
22
23import bpf
24import parser  # pylint: disable=wrong-import-order
25
26
class OptimizationStrategy(enum.Enum):
    """The strategies available for laying out the syscall number checks.

    LINEAR emits a linear chain of syscall number comparisons, whereas BST
    emits a binary search tree over the syscall numbers.
    """

    # A straight chain of syscall number checks. Best suited to policies that
    # list only a handful of syscalls.
    LINEAR = 'linear'

    # A binary search tree over the syscall numbers. Best suited to policies
    # that list many syscalls, none of which dominates the others.
    BST = 'bst'

    def __str__(self):
        """Return the raw value of the strategy (e.g. 'linear')."""
        return self.value
40
41
class SyscallPolicyEntry:
    """A single seccomp policy line in its parsed form."""

    def __init__(self, name, number, frequency):
        self.name, self.number, self.frequency = name, number, frequency
        # Initialised to zero; this class never updates it.
        self.accumulated = 0
        # No filter is attached on construction.
        self.filter = None

    def __repr__(self):
        # Only the instruction list of the filter is shown, when there is one.
        instructions = None
        if self.filter:
            instructions = self.filter.instructions
        template = ('SyscallPolicyEntry<name: %s, number: %d, '
                    'frequency: %d, filter: %r>')
        return template % (self.name, self.number, self.frequency,
                           instructions)

    def simulate(self, arch, syscall_number, *args):
        """Run the entry's filter against the given syscall and arguments.

        An entry that has no filter reports (0, 'ALLOW').
        """
        if self.filter:
            return bpf.simulate(self.filter.instructions, arch,
                                syscall_number, *args)
        return (0, 'ALLOW')
64
65
class SyscallPolicyRange:
    """A run of consecutive SyscallPolicyEntries sharing the same action."""

    def __init__(self, *entries):
        first, last = entries[0], entries[-1]
        # Half-open interval: [first syscall number, last syscall number + 1).
        self.numbers = (first.number, last.number + 1)
        self.frequency = sum(entry.frequency for entry in entries)
        self.accumulated = 0
        # The filter of the first entry stands for the whole range.
        self.filter = first.filter

    def __repr__(self):
        # Only the instruction list of the filter is shown, when there is one.
        instructions = None
        if self.filter:
            instructions = self.filter.instructions
        return 'SyscallPolicyRange<numbers: %r, frequency: %d, filter: %r>' % (
            self.numbers, self.frequency, instructions)

    def simulate(self, arch, syscall_number, *args):
        """Run the range's filter against the given syscall and arguments.

        A range that has no filter reports (0, 'ALLOW').
        """
        if self.filter:
            return self.filter.simulate(arch, syscall_number, *args)
        return (0, 'ALLOW')
85
86
def _convert_to_ranges(entries):
    """Yield SyscallPolicyRanges covering the entries, by syscall number.

    Neighbouring entries are merged into one range for as long as their
    filters compare equal and their syscall numbers are consecutive.
    """
    run = []
    for current in sorted(entries, key=lambda e: e.number):
        if run:
            previous = run[-1]
            # A different filter or a gap in the numbering closes the run.
            if (previous.filter != current.filter
                    or previous.number + 1 != current.number):
                yield SyscallPolicyRange(*run)
                run = []
        run.append(current)
    if run:
        yield SyscallPolicyRange(*run)
100
101
def _compile_single_range(entry,
                          accept_action,
                          reject_action,
                          lower_bound=0,
                          upper_bound=1e99):
    """Build the syscall number check for a single range.

    Returns a (cost, block) tuple, where cost is the number of comparisons
    emitted (1 or 2) and block is the bpf.SyscallEntry that performs them.
    """
    first, past_last = entry.numbers
    # A range that carries its own filter jumps to it; otherwise it accepts.
    matched = entry.filter or accept_action

    if past_last - first == 1:
        # Exactly one syscall in the range: match when |X == first|.
        return (1,
                bpf.SyscallEntry(
                    first, matched, reject_action, op=bpf.BPF_JEQ))
    if first == lower_bound:
        # The range starts at the lower bound: match when |X < past_last|.
        return (1,
                bpf.SyscallEntry(
                    past_last, reject_action, matched, op=bpf.BPF_JGE))
    if past_last == upper_bound:
        # The range ends at the upper bound: match when |X >= first|.
        return (1,
                bpf.SyscallEntry(
                    first, matched, reject_action, op=bpf.BPF_JGE))

    # The range touches neither bound, so both ends need a comparison:
    # match when |first <= X < past_last|.
    below_end = bpf.SyscallEntry(
        past_last, reject_action, matched, op=bpf.BPF_JGE)
    return (2,
            bpf.SyscallEntry(
                first, below_end, reject_action, op=bpf.BPF_JGE))
135
136
def _compile_ranges_linear(ranges, accept_action, reject_action):
    """Compile the ranges into a plain linear chain of comparisons.

    The chain is assembled back to front: ranges are visited from the least
    to the most frequent one, and every new check uses the chain built so far
    as its reject target. The most frequently-called syscalls therefore end
    up at the front of the chain.

    Returns a (cost, block) tuple with the accumulated cost of the chain and
    its first block.
    """
    total_cost = 0
    reaching_frequency = 0
    chain_head = reject_action
    by_rising_frequency = sorted(ranges, key=lambda r: r.frequency)
    for current in by_rising_frequency:
        step_cost, chain_head = _compile_single_range(
            current, accept_action, chain_head)
        # The check just added is paid for by this range and by every range
        # that sits behind it in the chain.
        reaching_frequency += current.frequency
        total_cost += reaching_frequency * step_cost
    return (total_cost, chain_head)
151
152
def _compile_entries_linear(entries, accept_action, reject_action):
    """Compile the entries into a linear chain and return its first block."""
    ranges = _convert_to_ranges(entries)
    _, chain_head = _compile_ranges_linear(ranges, accept_action,
                                           reject_action)
    return chain_head
156
157
def _compile_entries_bst(entries, accept_action, reject_action):
    """Compile the entries into a minimum-cost binary search tree.

    Args:
        entries: the SyscallPolicyEntry objects to compile.
        accept_action: the action used for ranges that have no filter of
            their own.
        reject_action: the action used when the syscall number matches no
            range.

    Returns:
        The root block of the generated decision tree.
    """
    # Instead of generating a linear list of comparisons, this method generates
    # a binary search tree, where some of the leaves can be linear chains of
    # comparisons.
    #
    # Even though we are going to perform a binary search over the syscall
    # number, we would still like to rotate some of the internal nodes of the
    # binary search tree so that more frequently-used syscalls can be accessed
    # more cheaply (i.e. fewer internal nodes need to be traversed to reach
    # them).
    #
    # This uses Dynamic Programming to generate all possible BSTs efficiently
    # (in O(n^3)) so that we can get the absolute minimum-cost tree that matches
    # all syscall entries. It does so by considering all of the O(n^2) possible
    # sub-intervals, and for each one of those try all of the O(n) partitions of
    # that sub-interval. At each step, it considers putting the remaining
    # entries in a linear comparison chain as well as another BST, and chooses
    # the option that minimizes the total overall cost.
    #
    # Between every pair of non-contiguous allowed syscalls, there are two
    # locally optimal options as to where to set the partition for the
    # subsequent ranges: aligned to the end of the left subrange or to the
    # beginning of the right subrange. The fact that these two options have
    # slightly different costs, combined with the possibility of a subtree to
    # use the linear chain strategy (which has a completely different cost
    # model), causes the target cost function that we are trying to optimize to
    # not be unimodal / convex. This unfortunately means that more clever
    # techniques like using ternary search (which would reduce the overall
    # complexity to O(n^2 log n)) do not work in all cases.
    ranges = list(_convert_to_ranges(entries))

    # Prefix sums of the frequencies, so that the total frequency of any
    # contiguous run of ranges can be obtained with a single subtraction.
    accumulated = 0
    for entry in ranges:
        accumulated += entry.frequency
        entry.accumulated = accumulated

    # Memoization cache to build the DP table top-down, which is easier to
    # understand. It is keyed by the |bounds| tuple only.
    memoized_costs = {}

    def _generate_syscall_bst(ranges, indices, bounds=(0, 2**64 - 1)):
        """Return the cheapest (cost, block) pair for a slice of |ranges|.

        |indices| is the half-open (start, end) slice of |ranges| being
        compiled, and |bounds| is the half-open interval of syscall numbers
        that encloses every range in that slice.
        """
        assert bounds[0] <= ranges[indices[0]].numbers[0], (indices, bounds)
        assert ranges[indices[1] - 1].numbers[1] <= bounds[1], (indices,
                                                                bounds)

        if bounds in memoized_costs:
            return memoized_costs[bounds]
        if indices[1] - indices[0] == 1:
            if bounds == ranges[indices[0]].numbers:
                # If bounds are tight around the syscall, it costs nothing.
                memoized_costs[bounds] = (0, ranges[indices[0]].filter
                                          or accept_action)
                return memoized_costs[bounds]
            result = _compile_single_range(ranges[indices[0]], accept_action,
                                           reject_action)
            memoized_costs[bounds] = (result[0] * ranges[indices[0]].frequency,
                                      result[1])
            return memoized_costs[bounds]

        # Try the linear model first and use that as the best estimate so far.
        best_cost = _compile_ranges_linear(ranges[slice(*indices)],
                                           accept_action, reject_action)

        # Now recursively go through all possible partitions of the interval
        # currently being considered.
        previous_accumulated = ranges[indices[0]].accumulated - ranges[
            indices[0]].frequency
        # The comparison at the root of this subtree is charged once for the
        # total frequency of all the ranges in the slice.
        bst_comparison_cost = (
            ranges[indices[1] - 1].accumulated - previous_accumulated)
        for i, entry in enumerate(ranges[slice(*indices)]):
            # Candidate cut points: the start of this range and, when there
            # is a previous range in the slice, the end of that range.
            candidates = [entry.numbers[0]]
            if i:
                candidates.append(ranges[i - 1 + indices[0]].numbers[1])
            for cutoff_bound in candidates:
                if not bounds[0] < cutoff_bound < bounds[1]:
                    continue
                if not indices[0] < i + indices[0] < indices[1]:
                    continue
                left_subtree = _generate_syscall_bst(
                    ranges, (indices[0], i + indices[0]),
                    (bounds[0], cutoff_bound))
                right_subtree = _generate_syscall_bst(
                    ranges, (i + indices[0], indices[1]),
                    (cutoff_bound, bounds[1]))
                # Compare the candidates by cost alone. Comparing the whole
                # tuples would, on a cost tie, fall through to ordering the
                # BPF blocks themselves. With |key|, a tie keeps the current
                # best candidate.
                best_cost = min(
                    best_cost,
                    (bst_comparison_cost + left_subtree[0] + right_subtree[0],
                     bpf.SyscallEntry(
                         cutoff_bound,
                         right_subtree[1],
                         left_subtree[1],
                         op=bpf.BPF_JGE)),
                    key=lambda candidate: candidate[0])

        memoized_costs[bounds] = best_cost
        return memoized_costs[bounds]

    return _generate_syscall_bst(ranges, (0, len(ranges)))[1]
255
256
class PolicyCompiler:
    """A compiler that turns Minijail seccomp policy files into BPF."""

    def __init__(self, arch):
        self._arch = arch

    def compile_file(self,
                     policy_filename,
                     *,
                     optimization_strategy,
                     kill_action,
                     include_depth_limit=10,
                     override_default_action=None):
        """Parse the given policy file and return its compiled BPF program.

        The file is parsed with parser.PolicyParser and every filter
        statement is compiled on its own. The syscall number checks are then
        laid out as a binary search tree when |optimization_strategy| is
        OptimizationStrategy.BST, and as a linear chain otherwise.
        """
        parsed_policy = parser.PolicyParser(
            self._arch,
            kill_action=kill_action,
            include_depth_limit=include_depth_limit,
            override_default_action=override_default_action,
        ).parse_file(policy_filename)

        entries = []
        for statement in parsed_policy.filter_statements:
            entries.append(
                self.compile_filter_statement(
                    statement, kill_action=kill_action))

        visitor = bpf.FlatteningVisitor(
            arch=self._arch, kill_action=kill_action)
        accept_action = bpf.Allow()
        reject_action = parsed_policy.default_action

        if not entries:
            # No filter statements: only the default action gets emitted.
            reject_action.accept(visitor)
            bpf.ValidateArch(reject_action).accept(visitor)
            return visitor.result

        if optimization_strategy == OptimizationStrategy.BST:
            compile_entries = _compile_entries_bst
        else:
            compile_entries = _compile_entries_linear
        root = compile_entries(entries, accept_action, reject_action)

        # The order in which the blocks are handed to the visitor is kept
        # exactly as is.
        root.accept(bpf.ArgFilterForwardingVisitor(visitor))
        reject_action.accept(visitor)
        accept_action.accept(visitor)
        bpf.ValidateArch(root).accept(visitor)
        return visitor.result

    def compile_filter_statement(self, filter_statement, *, kill_action):
        """Compile one parser.FilterStatement into a SyscallPolicyEntry.

        The returned entry has no filter when the action of the last filter
        in the statement compares equal to bpf.Allow(). Otherwise its filter
        is the flattened BPF for the statement.
        """
        syscall = filter_statement.syscall
        policy_entry = SyscallPolicyEntry(syscall.name, syscall.number,
                                          filter_statement.frequency)

        # At every step, the fall-through action is the one taken when the
        # immediate boolean condition does not match. It starts out as the
        # action that applies when the whole expression fails to match.
        fall_through = filter_statement.filters[-1].action
        if fall_through == bpf.Allow():
            return policy_entry

        # Walk the remaining filters backwards, so that the root of the DAG
        # ends up being the very first boolean operation in the filter chain.
        for filt in reversed(filter_statement.filters[:-1]):
            for conjunction in filt.expression:
                # Any conjunction that succeeds makes the whole expression
                # succeed, so the comparison that is evaluated last in it
                # jumps to the action of the filter.
                on_match = filt.action
                for atom in conjunction:
                    on_match = bpf.Atom(atom.argument_index, atom.op,
                                        atom.value, on_match, fall_through)
                # A failure in an earlier conjunction moves on to this one.
                fall_through = on_match

        # Lower all Atoms into WideAtoms.
        lowered = bpf.LoweringVisitor(arch=self._arch).process(fall_through)

        # Flatten the IR DAG into a single BasicBlock.
        flattener = bpf.FlatteningVisitor(
            arch=self._arch, kill_action=kill_action)
        lowered.accept(flattener)
        policy_entry.filter = flattener.result
        return policy_entry
342