1#!/usr/bin/env python3
2# Copyright 2016 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Unittests for the shell module."""
17
18import difflib
19import os
20import sys
21import unittest
22
# Put the directory two levels above this file (i.e. the parent of the
# directory containing it) at the front of sys.path, unless it is already
# first, so that the local `rh` package imported below can be found.
# NOTE(review): presumably that directory is the repo root holding `rh/`;
# confirm if this file is ever moved.
_path = os.path.realpath(__file__ + '/../..')
if sys.path[0] != _path:
    sys.path.insert(0, _path)
# Remove the temporary so it does not linger as a module-level name.
del _path
27
28# We have to import our local modules after the sys.path tweak.  We can't use
29# relative imports because this is an executable program, not a module.
30# pylint: disable=wrong-import-position
31import rh.shell
32
33
class DiffTestCase(unittest.TestCase):
    """Helper that includes diff output when failing."""

    def setUp(self):
        # One Differ is reused by every _assertEqual call in the test.
        self.differ = difflib.Differ()

    def _assertEqual(self, func, test_input, test_output, result):
        """Like assertEqual but with built in diff support.

        Args:
          func: Name of the function under test; only used in the message.
          test_input: The value that was passed to the function.
          test_output: The expected return value.
          result: The value the function actually returned.
        """
        # Differ.compare() expects sequences of strings and may raise
        # TypeError for other types (e.g. a function wrongly returning None).
        # Only build the diff for text so such a case is still reported as an
        # assertion failure rather than an error from inside difflib.
        if isinstance(test_output, str) and isinstance(result, str):
            diff = '\n'.join(self.differ.compare([test_output], [result]))
        else:
            diff = ''
        msg = ('Expected %s to translate %r to %r, but got %r\n%s' %
               (func, test_input, test_output, result, diff))
        self.assertEqual(test_output, result, msg)

    def _testData(self, functor, tests, check_type=True):
        """Process a dict of test data.

        Args:
          functor: The callable under test; it is called once per entry.
          tests: Dict mapping each expected output to the input that should
              produce it.
          check_type: If true, also require every result to be exactly str.
        """
        for test_output, test_input in tests.items():
            # subTest lets a run continue past a failing entry so that every
            # bad case is reported, each labelled with its input.
            with self.subTest(test_input=test_input):
                result = functor(test_input)
                self._assertEqual(functor.__name__, test_input, test_output,
                                  result)

                if check_type:
                    # Check the *result*, not the expected value: expected
                    # values are str literals, so checking them would always
                    # pass.  A plain str keeps %r output clean for logging.
                    self.assertIs(
                        type(result), str,
                        '%s returned %r, which is a %s, not a str' %
                        (functor.__name__, result, type(result).__name__))
58
59
class ShellQuoteTest(DiffTestCase):
    """Test the shell_quote & shell_unquote functions."""

    def testShellQuote(self):
        """Check shell_quote, shell_unquote, and that the two round-trip."""
        # Maps each quoted string shell_quote should produce to the raw
        # string it is given.
        expected_quoted = {
            "''": '',
            'a': u'a',
            "'a b c'": u'a b c',
            "'a\tb'": 'a\tb',
            "'/a$file'": '/a$file',
            "'/a#file'": '/a#file',
            """'b"c'""": 'b"c',
            "'a@()b'": 'a@()b',
            'j%k': 'j%k',
            r'''"s'a\$va\\rs"''': r"s'a$va\rs",
            r'''"\\'\\\""''': r'''\'\"''',
            r'''"'\\\$"''': r"""'\$""",
        }

        # Maps an unquoted result to a quoted input that shell_quote never
        # produces but which is still a valid bash escaped string.
        expected_unquoted = {
            r'''\$''': r'''"\\$"''',
        }

        def round_trip(text):
            quoted = rh.shell.shell_quote(text)
            return rh.shell.shell_unquote(quoted)

        self._testData(rh.shell.shell_quote, expected_quoted)
        self._testData(rh.shell.shell_unquote, expected_unquoted)

        # Quoting then unquoting must return the original string, both for
        # the raw strings above and for their quoted forms.
        for samples in (expected_quoted.values(), expected_quoted):
            self._testData(round_trip, dict(zip(samples, samples)), False)
96
97
class CmdToStrTest(DiffTestCase):
    """Test the cmd_to_str function."""

    def testCmdToStr(self):
        """Verify argument lists are rendered as shell-quoted command lines."""
        render = rh.shell.cmd_to_str
        # Each key is the expected command line; each value is the argument
        # list it should be rendered from.
        self._testData(render, {
            r"a b": ['a', 'b'],
            r"'a b' c": ['a b', 'c'],
            r'''a "b'c"''': ['a', "b'c"],
            r'''a "/'\$b" 'a b c' "xy'z"''':
                [u'a', "/'$b", 'a b c', "xy'z"],
            '': [],
        })
112
113
class BooleanShellTest(unittest.TestCase):
    """Test the boolean_shell_value function."""

    def testFull(self):
        """Verify inputs are accepted or rejected as expected."""
        parse = rh.shell.boolean_shell_value

        # None yields whichever default was supplied.
        self.assertTrue(parse(None, True))
        self.assertFalse(parse(None, False))

        # Unrecognized values (including non-strings and '') are rejected.
        for bogus in (1234, '', 'akldjsf', '"'):
            with self.assertRaises(ValueError):
                parse(bogus, True)

        # Recognized spellings win regardless of the default.
        truthy = ('yes', 'YES', 'YeS', 'y', 'Y', '1', 'true', 'True', 'TRUE')
        for word in truthy:
            for default in (True, False):
                self.assertTrue(parse(word, default))

        falsy = ('no', 'NO', 'nO', 'n', 'N', '0', 'false', 'False', 'FALSE')
        for word in falsy:
            for default in (True, False):
                self.assertFalse(parse(word, default))
133
134
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
137