1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17"""Unittests for the compiler module."""
18
19from __future__ import print_function
20
21import os
22import random
23import shutil
24import tempfile
25import unittest
26
27import arch
28import bpf
29import compiler
30import parser  # pylint: disable=wrong-import-order
31
# Directory that contains this test module; JSON fixtures live beneath it.
_TEST_DIR = os.path.dirname(os.path.abspath(__file__))

# Architecture description loaded from testdata/arch_64.json and shared by
# every test case in this module.
ARCH_64 = arch.Arch.load_from_json(
    os.path.join(_TEST_DIR, 'testdata/arch_64.json'))
35
36
class CompileFilterStatementTests(unittest.TestCase):
    """Tests for PolicyCompiler.compile_filter_statement."""

    def setUp(self):
        self.arch = ARCH_64
        self.compiler = compiler.PolicyCompiler(self.arch)

    def _compile(self, line):
        """Parses |line| as a single-statement policy and compiles it.

        The policy text is written to a temporary file and parsed with
        PolicyParser.parse_file(). The parsed policy must contain exactly one
        filter statement; its compiled block is returned.
        """
        with tempfile.NamedTemporaryFile(mode='w') as policy_file:
            policy_file.write(line)
            policy_file.flush()
            policy_parser = parser.PolicyParser(
                self.arch, kill_action=bpf.KillProcess())
            parsed_policy = policy_parser.parse_file(policy_file.name)
            # A TestCase assertion (unlike a bare assert) is not stripped
            # when Python runs with -O, so a malformed test policy always
            # fails loudly.
            self.assertEqual(len(parsed_policy.filter_statements), 1)
            return self.compiler.compile_filter_statement(
                parsed_policy.filter_statements[0],
                kill_action=bpf.KillProcess())

    def _simulate(self, block, *args):
        """Simulates |block| for the `read` syscall with syscall args |args|.

        Returns whatever block.simulate() returns; index [1] is the action
        name and [1:] additionally carries the action's data (e.g. errno).
        """
        return block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
                              *args)

    def test_allow(self):
        """Accept lines where the syscall is accepted unconditionally."""
        block = self._compile('read: allow')
        self.assertEqual(block.filter, None)
        self.assertEqual(self._simulate(block, 0)[1], 'ALLOW')
        self.assertEqual(self._simulate(block, 1)[1], 'ALLOW')

    def test_arg0_eq_generated_code(self):
        """Accept lines with an argument filter with ==."""
        block = self._compile('read: arg0 == 0x100')
        # It might be a bit brittle to check the generated code in each test
        # case instead of just the behavior, but there should be at least one
        # test where this happens.
        self.assertEqual(
            block.filter.instructions,
            [
                bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0,
                               bpf.arg_offset(0, True)),
                # Jump to KILL_PROCESS if the high word does not match.
                bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 0, 2, 0),
                bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0,
                               bpf.arg_offset(0, False)),
                # Jump to KILL_PROCESS if the low word does not match.
                bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 1, 0,
                               0x100),
                bpf.SockFilter(bpf.BPF_RET, 0, 0,
                               bpf.SECCOMP_RET_KILL_PROCESS),
                bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_ALLOW),
            ])

    def test_arg0_comparison_operators(self):
        """Accept lines with an argument filter with comparison operators."""
        biases = (-1, 0, 1)
        # For each operator, store the expectations of simulating the program
        # against the constant plus each entry from the |biases| array.
        cases = (
            ('==', ('KILL_PROCESS', 'ALLOW', 'KILL_PROCESS')),
            ('!=', ('ALLOW', 'KILL_PROCESS', 'ALLOW')),
            ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')),
            ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')),
            ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')),
            ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')),
        )
        for operator, expectations in cases:
            block = self._compile('read: arg0 %s 0x100' % operator)

            # Check the filter's behavior.
            for bias, expectation in zip(biases, expectations):
                self.assertEqual(
                    self._simulate(block, 0x100 + bias)[1], expectation,
                    'operator: %s, bias: %d' % (operator, bias))

    def test_arg0_mask_operator(self):
        """Accept lines with an argument filter with &."""
        block = self._compile('read: arg0 & 0x3')

        # Expected action for arg0 == index: allowed exactly when arg0 shares
        # at least one bit with the 0x3 mask.
        expectations = (
            'KILL_PROCESS',  # 0
            'ALLOW',  # 1
            'ALLOW',  # 2
            'ALLOW',  # 3
            'KILL_PROCESS',  # 4
            'ALLOW',  # 5
            'ALLOW',  # 6
            'ALLOW',  # 7
            'KILL_PROCESS',  # 8
        )
        for arg0, expectation in enumerate(expectations):
            self.assertEqual(
                self._simulate(block, arg0)[1], expectation,
                'arg0: %d' % arg0)

    def test_arg0_in_operator(self):
        """Accept lines with an argument filter with in."""
        block = self._compile('read: arg0 in 0x3')

        # The 'in' operator only ensures that no bits outside the mask are set,
        # which means that 0 is always allowed.
        expectations = (
            'ALLOW',  # 0
            'ALLOW',  # 1
            'ALLOW',  # 2
            'ALLOW',  # 3
            'KILL_PROCESS',  # 4
            'KILL_PROCESS',  # 5
            'KILL_PROCESS',  # 6
            'KILL_PROCESS',  # 7
            'KILL_PROCESS',  # 8
        )
        for arg0, expectation in enumerate(expectations):
            self.assertEqual(
                self._simulate(block, arg0)[1], expectation,
                'arg0: %d' % arg0)

    def test_arg0_short_gt_ge_comparisons(self):
        """Ensure that the short comparison optimization kicks in."""
        if self.arch.bits == 32:
            # Report the test as skipped rather than silently passing it:
            # there is no high word to optimize away on 32-bit arches.
            self.skipTest('short comparisons only apply to 64-bit arches')
        short_constant_str = '0xdeadbeef'
        short_constant = int(short_constant_str, base=0)
        long_constant_str = '0xbadc0ffee0ddf00d'
        long_constant = int(long_constant_str, base=0)
        biases = (-1, 0, 1)
        # For each operator, store the expectations of simulating the program
        # against the constant plus each entry from the |biases| array.
        cases = (
            ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')),
            ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')),
            ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')),
            ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')),
        )
        for operator, expectations in cases:
            short_block = self._compile(
                'read: arg0 %s %s' % (operator, short_constant_str))
            long_block = self._compile(
                'read: arg0 %s %s' % (operator, long_constant_str))

            # Check that the emitted code is shorter when the high word of the
            # constant is zero.
            self.assertLess(
                len(short_block.filter.instructions),
                len(long_block.filter.instructions))

            # Check the filter's behavior.
            for bias, expectation in zip(biases, expectations):
                message = 'operator: %s, bias: %d' % (operator, bias)
                self.assertEqual(
                    self._simulate(long_block, long_constant + bias)[1],
                    expectation, message)
                self.assertEqual(
                    self._simulate(short_block, short_constant + bias)[1],
                    expectation, message)

    def test_and_or(self):
        """Accept lines with a complex expression in DNF."""
        block = self._compile('read: arg0 == 0 && arg1 == 0 || arg0 == 1')

        # (arg0, arg1) pairs and the expected action for each.
        cases = (
            ((0, 0), 'ALLOW'),
            ((0, 1), 'KILL_PROCESS'),
            ((1, 0), 'ALLOW'),
            ((1, 1), 'ALLOW'),
        )
        for args, expectation in cases:
            self.assertEqual(
                self._simulate(block, *args)[1], expectation,
                'args: %r' % (args,))

    def test_trap(self):
        """Accept lines that trap unconditionally."""
        block = self._compile('read: trap')

        self.assertEqual(self._simulate(block, 0)[1], 'TRAP')

    def test_ret_errno(self):
        """Accept lines that return errno."""
        block = self._compile('read : arg0 == 0 || arg0 == 1 ; return 1')

        self.assertEqual(self._simulate(block, 0)[1:], ('ERRNO', 1))
        self.assertEqual(self._simulate(block, 1)[1:], ('ERRNO', 1))
        self.assertEqual(self._simulate(block, 2)[1], 'KILL_PROCESS')

    def test_ret_errno_unconditionally(self):
        """Accept lines that return errno unconditionally."""
        block = self._compile('read: return 1')

        self.assertEqual(self._simulate(block, 0)[1:], ('ERRNO', 1))

    def test_trace(self):
        """Accept lines that trace unconditionally."""
        block = self._compile('read: trace')

        self.assertEqual(self._simulate(block, 0)[1], 'TRACE')

    def test_user_notify(self):
        """Accept lines that notify unconditionally."""
        block = self._compile('read: user-notify')

        self.assertEqual(self._simulate(block, 0)[1], 'USER_NOTIF')

    def test_log(self):
        """Accept lines that log unconditionally."""
        block = self._compile('read: log')

        self.assertEqual(self._simulate(block, 0)[1], 'LOG')

    def test_mmap_write_xor_exec(self):
        """Accept the idiomatic filter for mmap."""
        block = self._compile(
            'read : arg0 in ~PROT_WRITE || arg0 in ~PROT_EXEC')

        prot_exec_and_write = 6
        # Enumerate every combination of the low four bits (0x0..0xf
        # inclusive); range(0, 0xf) used to skip 0xf.
        for prot in range(0x10):
            if (prot & prot_exec_and_write) == prot_exec_and_write:
                expectation = 'KILL_PROCESS'
            else:
                expectation = 'ALLOW'
            self.assertEqual(
                self._simulate(block, prot)[1], expectation,
                'prot: %#x' % prot)
307
308
class CompileFileTests(unittest.TestCase):
    """Tests for PolicyCompiler.compile_file."""

    def setUp(self):
        self.arch = ARCH_64
        self.compiler = compiler.PolicyCompiler(self.arch)
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def _write_file(self, filename, contents):
        """Writes |contents| to |filename| inside the per-test temp directory.

        Returns the full path of the written file.
        """
        path = os.path.join(self.tempdir, filename)
        with open(path, 'w') as outf:
            outf.write(contents)
        return path

    def test_compile(self):
        """Ensure compilation works with all strategies."""
        self._write_file(
            'test.frequency', """
            read: 1
            close: 10
        """)
        path = self._write_file(
            'test.policy', """
            @frequency ./test.frequency
            read: 1
            close: 1
        """)

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.LINEAR,
            kill_action=bpf.KillProcess())
        self.assertGreater(
            bpf.simulate(program.instructions, self.arch.arch_nr,
                         self.arch.syscalls['read'], 0)[0],
            bpf.simulate(program.instructions, self.arch.arch_nr,
                         self.arch.syscalls['close'], 0)[0],
        )

    def test_compile_bst(self):
        """Ensure every strategy makes the more frequent syscall cheaper.

        `close` is declared ten times more frequent than `read`, so under
        each optimization strategy reaching `read` must cost more than
        reaching `close`, and both must still be allowed.
        """
        self._write_file(
            'test.frequency', """
            read: 1
            close: 10
        """)
        path = self._write_file(
            'test.policy', """
            @frequency ./test.frequency
            read: 1
            close: 1
        """)

        for strategy in compiler.OptimizationStrategy:
            program = self.compiler.compile_file(
                path,
                optimization_strategy=strategy,
                kill_action=bpf.KillProcess())
            self.assertGreater(
                bpf.simulate(program.instructions, self.arch.arch_nr,
                             self.arch.syscalls['read'], 0)[0],
                bpf.simulate(program.instructions, self.arch.arch_nr,
                             self.arch.syscalls['close'], 0)[0],
            )
            self.assertEqual(
                bpf.simulate(program.instructions, self.arch.arch_nr,
                             self.arch.syscalls['read'], 0)[1], 'ALLOW')
            self.assertEqual(
                bpf.simulate(program.instructions, self.arch.arch_nr,
                             self.arch.syscalls['close'], 0)[1], 'ALLOW')

    def test_compile_empty_file(self):
        """Accept empty files."""
        path = self._write_file(
            'test.policy', """
            @default kill-thread
        """)

        for strategy in compiler.OptimizationStrategy:
            program = self.compiler.compile_file(
                path,
                optimization_strategy=strategy,
                kill_action=bpf.KillProcess())
            self.assertEqual(
                bpf.simulate(program.instructions, self.arch.arch_nr,
                             self.arch.syscalls['read'], 0)[1], 'KILL_THREAD')

    def test_compile_simulate(self):
        """Ensure policy reflects script by testing some random scripts."""
        iterations = 5
        for i in range(iterations):
            num_entries = 64 * (i + 1) // iterations
            # random.sample() requires a sequence (dict views are rejected
            # since Python 3.11), so materialize the syscall names first.
            syscalls = dict(
                zip(
                    random.sample(list(self.arch.syscalls), num_entries),
                    (random.randint(1, 1024) for _ in range(num_entries)),
                ))

            frequency_contents = '\n'.join(
                '%s: %d' % s for s in syscalls.items())
            policy_contents = '@frequency ./test.frequency\n' + '\n'.join(
                '%s: 1' % name for name in syscalls)

            self._write_file('test.frequency', frequency_contents)
            path = self._write_file('test.policy', policy_contents)

            for strategy in compiler.OptimizationStrategy:
                program = self.compiler.compile_file(
                    path,
                    optimization_strategy=strategy,
                    kill_action=bpf.KillProcess())
                for name, number in self.arch.syscalls.items():
                    expected_result = ('ALLOW'
                                       if name in syscalls else 'KILL_PROCESS')
                    self.assertEqual(
                        bpf.simulate(program.instructions, self.arch.arch_nr,
                                     number, 0)[1], expected_result,
                        ('syscall name: %s, syscall number: %d, '
                         'strategy: %s, policy:\n%s') %
                        (name, number, strategy, policy_contents))

    @unittest.skipIf(not int(os.getenv('SLOW_TESTS', '0')), 'slow')
    def test_compile_huge_policy(self):
        """Ensure jumps while compiling a huge policy are still valid."""
        # Given that the BST strategy is O(n^3), don't choose a crazy large
        # value, but it still needs to be around 128 so that we exercise the
        # codegen paths that depend on the length of the jump.
        #
        # Immediate jump offsets in BPF comparison instructions are limited to
        # 256 instructions, so given that every syscall filter consists of a
        # load and jump instructions, with 128 syscalls there will be at least
        # one jump that's further than 256 instructions.
        num_entries = 128
        # random.sample() requires a sequence (dict views are rejected since
        # Python 3.11), so materialize the (name, number) pairs first.
        syscalls = dict(
            random.sample(list(self.arch.syscalls.items()), num_entries))
        # Here we force every single filter to be distinct. Otherwise the
        # codegen layer will coalesce filters that compile to the same
        # instructions.
        policy_contents = '\n'.join(
            '%s: arg0 == %d' % s for s in syscalls.items())

        path = self._write_file('test.policy', policy_contents)

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.BST,
            kill_action=bpf.KillProcess())
        for name, number in self.arch.syscalls.items():
            expected_result = ('ALLOW'
                               if name in syscalls else 'KILL_PROCESS')
            self.assertEqual(
                bpf.simulate(program.instructions, self.arch.arch_nr,
                             self.arch.syscalls[name], number)[1],
                expected_result)
            self.assertEqual(
                bpf.simulate(program.instructions, self.arch.arch_nr,
                             self.arch.syscalls[name], number + 1)[1],
                'KILL_PROCESS')

    def test_compile_huge_filter(self):
        """Ensure a policy with very long per-syscall filters compiles."""
        # This is intended to force cases where the AST visitation would result
        # in a combinatorial explosion of calls to Block.accept(). An optimized
        # implementation should be O(n).
        num_entries = 128
        syscalls = {}
        # Here we force every single filter to be distinct. Otherwise the
        # codegen layer will coalesce filters that compile to the same
        # instructions.
        policy_contents = []
        # random.sample() requires a sequence (dict views are rejected since
        # Python 3.11), so materialize the syscall names first.
        for name in random.sample(list(self.arch.syscalls), num_entries):
            values = random.sample(range(1024), num_entries)
            syscalls[name] = values
            policy_contents.append(
                '%s: %s' % (name, ' || '.join('arg0 == %d' % value
                                              for value in values)))

        path = self._write_file('test.policy', '\n'.join(policy_contents))

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.LINEAR,
            kill_action=bpf.KillProcess())
        for name, values in syscalls.items():
            self.assertEqual(
                bpf.simulate(program.instructions,
                             self.arch.arch_nr, self.arch.syscalls[name],
                             random.choice(values))[1], 'ALLOW')
            # 1025 is outside range(1024), so it never matches any filter.
            self.assertEqual(
                bpf.simulate(program.instructions, self.arch.arch_nr,
                             self.arch.syscalls[name], 1025)[1],
                'KILL_PROCESS')
504
505
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when it is
    # executed directly as a script.
    unittest.main(module='__main__')
508