from __future__ import print_function
import atexit
import filecmp
import glob
import itertools
import os
import shutil
import sys
import sysconfig
import tempfile
import unittest
10
11
project_dir = os.path.abspath(os.path.join(__file__, '..', '..', '..'))
src_dir = os.path.join(project_dir, 'python')
test_dir = os.path.join(project_dir, 'tests')

python_exe = sys.executable or 'python'
bro_path = os.path.join(src_dir, 'bro.py')
# Command prefix used to run the bro.py command-line tool with this interpreter.
BRO_ARGS = [python_exe, bro_path]

# Get the platform/version-specific build folder, assumed to live under
# <project>/bin (presumably the configured build base -- confirm in setup.cfg).
# NOTE(review): newer setuptools name this folder
# 'lib.<platform>-<implementation cache tag>' (e.g. 'lib.linux-x86_64-cpython-311');
# if so this path will not exist and the prepends below have no effect.
platform_lib_name = 'lib.{platform}-{version[0]}.{version[1]}'.format(
    platform=sysconfig.get_platform(), version=sys.version_info)
build_dir = os.path.join(project_dir, 'bin', platform_lib_name)

# Prepend the build folder to sys.path (in-process imports) and to the
# PYTHONPATH of TEST_ENV (the environment copy intended for subprocesses).
if build_dir not in sys.path:
    sys.path.insert(0, build_dir)
TEST_ENV = os.environ.copy()
_inherited_pythonpath = TEST_ENV.get('PYTHONPATH')
if _inherited_pythonpath:
    TEST_ENV['PYTHONPATH'] = build_dir + os.pathsep + _inherited_pythonpath
else:
    # Treat an empty PYTHONPATH like an unset one; appending it would leave a
    # trailing empty search-path entry.
    TEST_ENV['PYTHONPATH'] = build_dir

TESTDATA_DIR = os.path.join(test_dir, 'testdata')

TESTDATA_FILES = [
    'empty',  # Empty file
    '10x10y',  # Small text
    'alice29.txt',  # Large text
    'random_org_10k.bin',  # Small data
    'mapsdatazrh',  # Large data
]

TESTDATA_PATHS = [os.path.join(TESTDATA_DIR, f) for f in TESTDATA_FILES]

# glob() returns files in arbitrary order; sort so the generated decompression
# tests are deterministic across machines and runs.
TESTDATA_PATHS_FOR_DECOMPRESSION = sorted(glob.glob(
    os.path.join(TESTDATA_DIR, '*.compressed')))

# Per-process scratch directory for test outputs. It is removed when the
# interpreter exits so that repeated runs do not leave stale directories behind.
TEMP_DIR = tempfile.mkdtemp()
atexit.register(shutil.rmtree, TEMP_DIR, ignore_errors=True)
51
52
def get_temp_compressed_name(filename):
    """Return the scratch path in TEMP_DIR for the compressed copy of *filename*.

    Only the base name of *filename* is kept, with a ``.bro`` suffix appended.
    """
    scratch_name = os.path.basename(filename + '.bro')
    return os.path.join(TEMP_DIR, scratch_name)
55
56
def get_temp_uncompressed_name(filename):
    """Return the scratch path in TEMP_DIR for the decompressed copy of *filename*.

    Only the base name of *filename* is kept, with an ``.unbro`` suffix appended.
    """
    scratch_name = os.path.basename(filename + '.unbro')
    return os.path.join(TEMP_DIR, scratch_name)
59
60
def bind_method_args(method, *args, **kwargs):
    """Pre-apply *args*/*kwargs* to *method*, leaving only ``self`` open.

    The result is a plain function, so when stored as a class attribute it
    binds ``self`` like any normal method (functools.partial would not).
    """
    def bound(self):
        return method(self, *args, **kwargs)
    return bound
63
64
def generate_test_methods(test_case_class,
                          for_decompression=False,
                          variants=None):
    """Attach one concrete test method per (template, data file, options) combo.

    Every attribute of *test_case_class* whose name starts with ``_test`` is
    used as a template that is called as ``template(self, testdata_path,
    **options)``.  For each template, each test data file and each combination
    of option values, a method named
    ``test_<template>_<options>_<file stem>`` is set on the class.  Having a
    separate method per scenario makes problems with a specific compression
    scenario easier to identify.

    Args:
      test_case_class: the TestCase subclass to extend in place.
      for_decompression: if true, use TESTDATA_PATHS_FOR_DECOMPRESSION instead
        of TESTDATA_PATHS as the data files.
      variants: optional mapping of keyword name -> iterable of values.  One
        variant is generated for every element of the cartesian product of
        the value lists.  ``None`` or an empty mapping yields a single variant
        with no options (and an empty ``<options>`` name component).
    """
    if for_decompression:
        paths = TESTDATA_PATHS_FOR_DECOMPRESSION
    else:
        paths = TESTDATA_PATHS

    # Each entry is [options name fragment, keyword arguments dict].
    opts = []
    if variants:
        # One list of (key, value) pairs per option key; their product gives
        # every combination of option values.
        opts_list = [list(itertools.product([k], v))
                     for k, v in variants.items()]
        for o in itertools.product(*opts_list):
            opts_name = '_'.join(str(i) for i in itertools.chain(*o))
            opts.append([opts_name, dict(o)])
    else:
        opts.append(['', {}])

    # Materialize the template list before adding attributes to the class.
    templates = [m for m in dir(test_case_class) if m.startswith('_test')]
    for method in templates:
        template = getattr(test_case_class, method)
        for testdata in paths:
            # The file stem is invariant across option variants; compute it once.
            f = os.path.splitext(os.path.basename(testdata))[0]
            for (opts_name, opts_dict) in opts:
                name = 'test_{method}_{options}_{file}'.format(
                    method=method, options=opts_name, file=f)
                func = bind_method_args(template, testdata, **opts_dict)
                setattr(test_case_class, name, func)
94
95
class TestCase(unittest.TestCase):
    """Base test case: scratch-file cleanup plus a file-contents assertion."""

    def tearDown(self):
        """Delete the per-test scratch files in TEMP_DIR.

        Scratch names are derived from both TESTDATA_PATHS and the
        ``*.compressed`` inputs in TESTDATA_PATHS_FOR_DECOMPRESSION, so outputs
        named after either list are removed.  Files that were never created
        are ignored.
        """
        for f in itertools.chain(TESTDATA_PATHS,
                                 TESTDATA_PATHS_FOR_DECOMPRESSION):
            for temp_name in (get_temp_compressed_name(f),
                              get_temp_uncompressed_name(f)):
                try:
                    os.unlink(temp_name)
                except OSError:
                    # Best effort: a given test may not produce this file.
                    pass

    def assertFilesMatch(self, first, second):
        """Fail unless files *first* and *second* have byte-identical contents."""
        self.assertTrue(
            filecmp.cmp(first, second, shallow=False),
            'File {} differs from {}'.format(first, second))
113