1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17"""A BPF compiler for the Minijail policy file."""
18
19from __future__ import print_function
20
21import enum
22
23try:
24    import bpf
25    import parser  # pylint: disable=wrong-import-order
26except ImportError:
27    from minijail import bpf
28    from minijail import parser  # pylint: disable=wrong-import-order
29
30
class OptimizationStrategy(enum.Enum):
    """Ways in which the syscall number checks can be laid out."""

    # A flat chain of syscall number comparisons. Suited to policies that
    # only list a handful of syscalls.
    LINEAR = 'linear'

    # A binary search tree over the syscall numbers. Suited to large policies
    # in which no single syscall dominates.
    BST = 'bst'

    def __str__(self):
        """Return the member's raw string value (e.g. 'bst')."""
        raw_value = self.value
        return raw_value
44
45
class SyscallPolicyEntry:
    """A single syscall of a seccomp policy, with its optional filter."""

    def __init__(self, name, number, frequency):
        self.name, self.number = name, number
        self.frequency = frequency
        # Starts at zero; this class never updates it itself.
        self.accumulated = 0
        # No filter attached yet. simulate() treats a missing filter as an
        # unconditional ALLOW.
        self.filter = None

    def __repr__(self):
        # Only show the instructions when a filter is actually attached.
        instructions = self.filter.instructions if self.filter else None
        template = ('SyscallPolicyEntry<name: %s, number: %d, '
                    'frequency: %d, filter: %r>')
        return template % (self.name, self.number, self.frequency,
                           instructions)

    def simulate(self, arch, syscall_number, *args):
        """Run the entry's filter against one syscall invocation.

        Returns (0, 'ALLOW') when the entry has no filter; otherwise returns
        whatever bpf.simulate() produces for the filter's instructions.
        """
        if self.filter:
            return bpf.simulate(self.filter.instructions, arch,
                                syscall_number, *args)
        return (0, 'ALLOW')
68
69
class SyscallPolicyRange:
    """A run of consecutive syscall numbers that all share one filter."""

    def __init__(self, *entries):
        first, last = entries[0], entries[-1]
        # Half-open interval: [first.number, last.number + 1).
        self.numbers = (first.number, last.number + 1)
        self.frequency = sum(member.frequency for member in entries)
        self.accumulated = 0
        # The first entry's filter stands in for the whole range.
        self.filter = first.filter

    def __repr__(self):
        instructions = self.filter.instructions if self.filter else None
        template = 'SyscallPolicyRange<numbers: %r, frequency: %d, filter: %r>'
        return template % (self.numbers, self.frequency, instructions)

    def simulate(self, arch, syscall_number, *args):
        """Run the range's filter against one syscall invocation.

        Returns (0, 'ALLOW') when the range has no filter; otherwise
        delegates to the filter's own simulate() method.
        """
        if self.filter:
            return self.filter.simulate(arch, syscall_number, *args)
        return (0, 'ALLOW')
89
90
91def _convert_to_ranges(entries):
92    entries = list(sorted(entries, key=lambda r: r.number))
93    lower = 0
94    while lower < len(entries):
95        upper = lower + 1
96        while upper < len(entries):
97            if entries[upper - 1].filter != entries[upper].filter:
98                break
99            if entries[upper - 1].number + 1 != entries[upper].number:
100                break
101            upper += 1
102        yield SyscallPolicyRange(*entries[lower:upper])
103        lower = upper
104
105
def _compile_single_range(entry,
                          accept_action,
                          reject_action,
                          lower_bound=0,
                          upper_bound=1e99):
    """Build the syscall number check(s) for one range.

    Args:
        entry: object with a `numbers` pair (first, one-past-last) and a
            `filter` attribute.
        accept_action: target for a matching syscall when the range has no
            filter of its own.
        reject_action: target for a syscall outside the range.
        lower_bound: smallest syscall number that can reach this check.
        upper_bound: exclusive upper limit of the numbers that can reach
            this check.

    Returns:
        A (comparisons, block) tuple: the number of comparisons emitted
        (1 or 2) and the bpf.SyscallEntry at the root of the check.
    """
    action = entry.filter if entry.filter else accept_action
    first, past_last = entry.numbers

    if past_last - first == 1:
        # Single syscall.
        # Accept if |X == nr|.
        return (1,
                bpf.SyscallEntry(
                    first, action, reject_action, op=bpf.BPF_JEQ))
    if first == lower_bound:
        # Syscall range aligned with the lower bound.
        # Accept if |X < nr[1]|.
        return (1,
                bpf.SyscallEntry(
                    past_last, reject_action, action, op=bpf.BPF_JGE))
    if past_last == upper_bound:
        # Syscall range aligned with the upper bound.
        # Accept if |X >= nr[0]|.
        return (1,
                bpf.SyscallEntry(
                    first, action, reject_action, op=bpf.BPF_JGE))

    # Syscall range in the middle.
    # Accept if |nr[0] <= X < nr[1]|.
    below_upper = bpf.SyscallEntry(
        past_last, reject_action, action, op=bpf.BPF_JGE)
    return (2,
            bpf.SyscallEntry(
                first, below_upper, reject_action, op=bpf.BPF_JGE))
139
140
141def _compile_ranges_linear(ranges, accept_action, reject_action):
142    # Compiles the list of ranges into a simple linear list of comparisons. In
143    # order to make the generated code a bit more efficient, we sort the
144    # ranges by frequency, so that the most frequently-called syscalls appear
145    # earlier in the chain.
146    cost = 0
147    accumulated_frequencies = 0
148    next_action = reject_action
149    for entry in sorted(ranges, key=lambda r: r.frequency):
150        current_cost, next_action = _compile_single_range(
151            entry, accept_action, next_action)
152        accumulated_frequencies += entry.frequency
153        cost += accumulated_frequencies * current_cost
154    return (cost, next_action)
155
156
def _compile_entries_linear(entries, accept_action, reject_action):
    """Compile the entries into a linear chain and return its head block."""
    ranges = _convert_to_ranges(entries)
    _, chain_head = _compile_ranges_linear(ranges, accept_action,
                                           reject_action)
    return chain_head
160
161
def _compile_entries_bst(entries, accept_action, reject_action):
    """Compile the entries into a minimum-cost binary search tree of checks.

    Args:
        entries: the syscall policy entries to compile.
        accept_action: target for a matching syscall whose entry has no
            filter of its own.
        reject_action: target for syscalls that match no entry.

    Returns:
        The root block of the generated decision tree. `reject_action`
        itself is returned when there are no entries, which is also what
        the linear strategy yields for an empty list.
    """
    # Instead of generating a linear list of comparisons, this method generates
    # a binary search tree, where some of the leaves can be linear chains of
    # comparisons.
    #
    # Even though we are going to perform a binary search over the syscall
    # number, we would still like to rotate some of the internal nodes of the
    # binary search tree so that more frequently-used syscalls can be accessed
    # more cheaply (i.e. fewer internal nodes need to be traversed to reach
    # them).
    #
    # This uses Dynamic Programming to generate all possible BSTs efficiently
    # (in O(n^3)) so that we can get the absolute minimum-cost tree that matches
    # all syscall entries. It does so by considering all of the O(n^2) possible
    # sub-intervals, and for each one of those try all of the O(n) partitions of
    # that sub-interval. At each step, it considers putting the remaining
    # entries in a linear comparison chain as well as another BST, and chooses
    # the option that minimizes the total overall cost.
    #
    # Between every pair of non-contiguous allowed syscalls, there are two
    # locally optimal options as to where to set the partition for the
    # subsequent ranges: aligned to the end of the left subrange or to the
    # beginning of the right subrange. The fact that these two options have
    # slightly different costs, combined with the possibility of a subtree to
    # use the linear chain strategy (which has a completely different cost
    # model), causes the target cost function that we are trying to optimize to
    # not be unimodal / convex. This unfortunately means that more clever
    # techniques like using ternary search (which would reduce the overall
    # complexity to O(n^2 log n)) do not work in all cases.
    ranges = list(_convert_to_ranges(entries))
    if not ranges:
        # Nothing to match: every syscall takes the reject action.
        return reject_action

    # Prefix sums of the frequencies, so that the total frequency of any
    # sub-interval of ranges can be computed with a single subtraction.
    accumulated = 0
    for entry in ranges:
        accumulated += entry.frequency
        entry.accumulated = accumulated

    # Memoization cache to build the DP table top-down, which is easier to
    # understand. Keyed by the (lower, upper) syscall number bounds.
    memoized_costs = {}

    def _generate_syscall_bst(ranges, indices, bounds=(0, 2**64 - 1)):
        """Return the cheapest (cost, block) for ranges[indices[0]:indices[1]].

        `bounds` is the half-open interval of syscall numbers that can reach
        the generated block.
        """
        assert bounds[0] <= ranges[indices[0]].numbers[0], (indices, bounds)
        assert ranges[indices[1] - 1].numbers[1] <= bounds[1], (indices,
                                                                bounds)

        if bounds in memoized_costs:
            return memoized_costs[bounds]
        if indices[1] - indices[0] == 1:
            if bounds == ranges[indices[0]].numbers:
                # If bounds are tight around the syscall, it costs nothing.
                memoized_costs[bounds] = (0, ranges[indices[0]].filter
                                          or accept_action)
                return memoized_costs[bounds]
            result = _compile_single_range(ranges[indices[0]], accept_action,
                                           reject_action)
            memoized_costs[bounds] = (result[0] * ranges[indices[0]].frequency,
                                      result[1])
            return memoized_costs[bounds]

        # Try the linear model first and use that as the best estimate so far.
        best_cost = _compile_ranges_linear(ranges[indices[0]:indices[1]],
                                           accept_action, reject_action)

        # Now recursively go through all possible partitions of the interval
        # currently being considered. Every syscall in the interval pays for
        # the comparison at the root of the subtree.
        previous_accumulated = ranges[indices[0]].accumulated - ranges[
            indices[0]].frequency
        bst_comparison_cost = (
            ranges[indices[1] - 1].accumulated - previous_accumulated)
        # A split must leave at least one range on each side, so the first
        # range of the interval is never the start of the right-hand side.
        for split in range(indices[0] + 1, indices[1]):
            # The two locally optimal cut points: the start of the right-hand
            # range, then the end of the left-hand one.
            candidates = (ranges[split].numbers[0],
                          ranges[split - 1].numbers[1])
            for cutoff_bound in candidates:
                if not bounds[0] < cutoff_bound < bounds[1]:
                    continue
                left_cost, left_action = _generate_syscall_bst(
                    ranges, (indices[0], split), (bounds[0], cutoff_bound))
                right_cost, right_action = _generate_syscall_bst(
                    ranges, (split, indices[1]), (cutoff_bound, bounds[1]))
                candidate_cost = bst_comparison_cost + left_cost + right_cost
                # Compare the costs only. Comparing the whole (cost, block)
                # tuples would fall through to comparing the blocks
                # themselves whenever two costs tie, which is meaningless and
                # fails for objects that define no ordering. On a tie the
                # earlier candidate is kept.
                if candidate_cost < best_cost[0]:
                    best_cost = (candidate_cost,
                                 bpf.SyscallEntry(
                                     cutoff_bound,
                                     right_action,
                                     left_action,
                                     op=bpf.BPF_JGE))

        memoized_costs[bounds] = best_cost
        return memoized_costs[bounds]

    return _generate_syscall_bst(ranges, (0, len(ranges)))[1]
259
260
class PolicyCompiler:
    """A compiler that turns Minijail seccomp policy files into BPF."""

    def __init__(self, arch):
        # Handed to the policy parser and to every BPF visitor created here.
        self._arch = arch

    def compile_file(self,
                     policy_filename,
                     *,
                     optimization_strategy,
                     kill_action,
                     include_depth_limit=10,
                     override_default_action=None):
        """Parse the given policy file and return the compiled BPF program.

        Args:
            policy_filename: path of the policy file, handed to the parser.
            optimization_strategy: OptimizationStrategy.BST selects the
                binary search tree layout; any other value selects the
                linear chain.
            kill_action: forwarded to the parser and to the BPF visitors.
            include_depth_limit: forwarded to the parser.
            override_default_action: forwarded to the parser.

        Returns:
            The `result` of the flattening visitor after every action has
            been visited.
        """
        parsed_policy = parser.PolicyParser(
            self._arch,
            kill_action=kill_action,
            include_depth_limit=include_depth_limit,
            override_default_action=override_default_action,
        ).parse_file(policy_filename)

        entries = []
        for statement in parsed_policy.filter_statements:
            entries.append(
                self.compile_filter_statement(
                    statement, kill_action=kill_action))

        visitor = bpf.FlatteningVisitor(
            arch=self._arch, kill_action=kill_action)
        accept_action = bpf.Allow()
        reject_action = parsed_policy.default_action

        if not entries:
            # Empty policy: the default action is the whole program.
            reject_action.accept(visitor)
            bpf.ValidateArch(reject_action).accept(visitor)
            return visitor.result

        if optimization_strategy == OptimizationStrategy.BST:
            compile_entries = _compile_entries_bst
        else:
            compile_entries = _compile_entries_linear
        next_action = compile_entries(entries, accept_action, reject_action)

        # The order of the visits below is kept exactly as it was.
        next_action.accept(bpf.ArgFilterForwardingVisitor(visitor))
        reject_action.accept(visitor)
        accept_action.accept(visitor)
        bpf.ValidateArch(next_action).accept(visitor)
        return visitor.result

    def compile_filter_statement(self, filter_statement, *, kill_action):
        """Compile one parser.FilterStatement into a SyscallPolicyEntry.

        The returned entry has no filter when the statement's final action
        is an unconditional allow; otherwise its filter is the flattened
        block built from the statement's filters.
        """
        syscall = filter_statement.syscall
        policy_entry = SyscallPolicyEntry(syscall.name, syscall.number,
                                          filter_statement.frequency)

        # At every step the false action is the one taken when the immediate
        # boolean condition does not match, so the starting value is the
        # action that applies when the whole expression fails to match.
        false_action = filter_statement.filters[-1].action
        if false_action == bpf.Allow():
            return policy_entry

        # Walk the remaining filters back to front so that the root of the
        # DAG ends up being the very first boolean operation in the chain.
        for filt in reversed(filter_statement.filters[:-1]):
            for conjunction in filt.expression:
                # A conjunction that succeeds makes the whole expression
                # succeed, so its final comparison jumps to the filter's own
                # action. Each atom chains onto the block built before it.
                true_action = filt.action
                for atom in conjunction:
                    true_action = bpf.Atom(atom.argument_index, atom.op,
                                           atom.value, true_action,
                                           false_action)
                # A later conjunction falls back to this one on failure.
                false_action = true_action

        # Lower all Atoms into WideAtoms.
        lowered_filter = bpf.LoweringVisitor(
            arch=self._arch).process(false_action)

        # Flatten the IR DAG into a single BasicBlock.
        flattener = bpf.FlatteningVisitor(
            arch=self._arch, kill_action=kill_action)
        lowered_filter.accept(flattener)
        policy_entry.filter = flattener.result
        return policy_entry
346