1#!/usr/bin/env python
2#
3# Copyright 2017, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Unittests for cli_translator."""
18
19import unittest
20import os
21import re
22import mock
23
24import atest_error
25import cli_translator as cli_t
26import constants
27import test_finder_handler
28import test_mapping
29import unittest_constants as uc
30import unittest_utils
31from test_finders import test_finder_base
32
# TEST_MAPPING related consts
# Folder containing the 'test_mapping_sample' file that
# test_find_tests_by_test_mapping reads.
TEST_MAPPING_DIR = os.path.join(
    uc.TEST_DATA_DIR, 'test_mapping', 'folder1')

# Matches a `find <dir> ...` command line; group 1 captures the first token
# after `find `. NOTE(review): no use is visible in this file — confirm it is
# still needed.
SEARCH_DIR_RE = re.compile(r'^find ([^ ]*).*$')
37
38
#pylint: disable=unused-argument
def gettestinfos_side_effect(test_names, test_mapping_test_details=None):
    """Stand-in for CLITranslator._get_test_infos, used as a mock side_effect.

    Args:
        test_names: Iterable of test names to resolve.
        test_mapping_test_details: Ignored; accepted so calls that pass
            test-mapping details are still valid.

    Returns:
        A set holding uc.MODULE_INFO for a name equal to uc.MODULE_NAME and
        uc.CLASS_INFO for a name equal to uc.CLASS_NAME. Any other name
        contributes nothing.
    """
    known_infos = ((uc.MODULE_NAME, uc.MODULE_INFO),
                   (uc.CLASS_NAME, uc.CLASS_INFO))
    return {info
            for name in test_names
            for known_name, info in known_infos
            if name == known_name}
49
50
#pylint: disable=protected-access
#pylint: disable=no-self-use
class CLITranslatorUnittests(unittest.TestCase):
    """Unit tests for cli_translator.py (imported here as cli_t)."""

    def setUp(self):
        """Create a fresh CLITranslator before every test."""
        self.ctr = cli_t.CLITranslator()

    @mock.patch.object(test_finder_handler, 'get_find_methods_for_test')
    def test_get_test_infos(self, mock_getfindmethods):
        """Test _get_test_infos method.

        get_find_methods_for_test is mocked to return a single Finder whose
        find method is swapped per scenario, so each scenario controls which
        test info (if any) is produced for a test name.
        """
        # Stub find methods; the second argument is the test name to resolve.
        find_method_return_module_info = lambda x, y: uc.MODULE_INFO
        # pylint: disable=invalid-name
        find_method_return_module_class_info = (lambda x, test: uc.MODULE_INFO
                                                if test == uc.MODULE_NAME
                                                else uc.CLASS_INFO)
        find_method_return_nothing = lambda x, y: None
        one_test = [uc.MODULE_NAME]
        mult_test = [uc.MODULE_NAME, uc.CLASS_NAME]

        # Let's make sure we return what we expect.
        expected_test_infos = {uc.MODULE_INFO}
        mock_getfindmethods.return_value = [
            test_finder_base.Finder(None, find_method_return_module_info)]
        unittest_utils.assert_strict_equal(
            self, self.ctr._get_test_infos(one_test), expected_test_infos)

        # Check we receive multiple test infos.
        expected_test_infos = {uc.MODULE_INFO, uc.CLASS_INFO}
        mock_getfindmethods.return_value = [
            test_finder_base.Finder(None, find_method_return_module_class_info)]
        unittest_utils.assert_strict_equal(
            self, self.ctr._get_test_infos(mult_test), expected_test_infos)

        # Let's make sure we raise an error when we have no tests found.
        mock_getfindmethods.return_value = [
            test_finder_base.Finder(None, find_method_return_nothing)]
        with self.assertRaises(atest_error.NoTestFoundError):
            self.ctr._get_test_infos(one_test)

        # Check the method works for test mapping: each test detail's options
        # must show up as TI_MODULE_ARG in the matching test info's data.
        # NOTE(review): the stub find methods return the shared uc fixture
        # objects, so this presumably writes into uc.MODULE_INFO/uc.CLASS_INFO
        # data for the rest of the test run — confirm, and copy if it matters.
        test_detail1 = test_mapping.TestDetail(uc.TEST_MAPPING_TEST)
        test_detail2 = test_mapping.TestDetail(uc.TEST_MAPPING_TEST_WITH_OPTION)
        expected_test_infos = {uc.MODULE_INFO, uc.CLASS_INFO}
        mock_getfindmethods.return_value = [
            test_finder_base.Finder(None, find_method_return_module_class_info)]
        test_infos = self.ctr._get_test_infos(
            mult_test, [test_detail1, test_detail2])
        unittest_utils.assert_strict_equal(
            self, test_infos, expected_test_infos)
        for test_info in test_infos:
            if test_info == uc.MODULE_INFO:
                self.assertEqual(
                    test_detail1.options,
                    test_info.data[constants.TI_MODULE_ARG])
            else:
                self.assertEqual(
                    test_detail2.options,
                    test_info.data[constants.TI_MODULE_ARG])

    @mock.patch.object(cli_t.CLITranslator, '_find_tests_by_test_mapping')
    @mock.patch.object(cli_t.CLITranslator, '_get_test_infos',
                       side_effect=gettestinfos_side_effect)
    #pylint: disable=unused-argument
    def test_translate(self, _info, mock_testmapping):
        """Test translate method.

        _get_test_infos is replaced by gettestinfos_side_effect, and
        _find_tests_by_test_mapping is mocked so the empty-input case is
        resolved through stubbed test-mapping details.
        """
        # Check that we can find a class.
        targets, test_infos = self.ctr.translate([uc.CLASS_NAME])
        unittest_utils.assert_strict_equal(
            self, targets, uc.CLASS_BUILD_TARGETS)
        unittest_utils.assert_strict_equal(self, test_infos, {uc.CLASS_INFO})

        # Check that we get all the build targets we expect.
        targets, test_infos = self.ctr.translate([uc.MODULE_NAME,
                                                  uc.CLASS_NAME])
        unittest_utils.assert_strict_equal(
            self, targets, uc.MODULE_CLASS_COMBINED_BUILD_TARGETS)
        unittest_utils.assert_strict_equal(self, test_infos, {uc.MODULE_INFO,
                                                              uc.CLASS_INFO})

        # Check that test mappings feeds into get_test_info properly.
        test_detail1 = test_mapping.TestDetail(uc.TEST_MAPPING_TEST)
        test_detail2 = test_mapping.TestDetail(uc.TEST_MAPPING_TEST_WITH_OPTION)
        mock_testmapping.return_value = ([test_detail1, test_detail2], None)
        targets, test_infos = self.ctr.translate([])
        unittest_utils.assert_strict_equal(
            self, targets, uc.MODULE_CLASS_COMBINED_BUILD_TARGETS)
        unittest_utils.assert_strict_equal(self, test_infos, {uc.MODULE_INFO,
                                                              uc.CLASS_INFO})

    def test_find_tests_by_test_mapping(self):
        """Test _find_tests_by_test_mapping method.

        Reads the 'test_mapping_sample' file under TEST_MAPPING_DIR with the
        default group and then with the postsubmit group.
        """
        tests, all_tests = self.ctr._find_tests_by_test_mapping(
            path=TEST_MAPPING_DIR, file_name='test_mapping_sample')
        expected = {'test1', 'test2'}
        expected_all_tests = {'presubmit': expected,
                              'postsubmit': {'test3'}}
        self.assertEqual(expected, tests)
        self.assertEqual(expected_all_tests, all_tests)

        # Requesting the postsubmit group also yields the presubmit tests.
        tests, all_tests = self.ctr._find_tests_by_test_mapping(
            path=TEST_MAPPING_DIR, test_group=constants.TEST_GROUP_POSTSUBMIT,
            file_name='test_mapping_sample')
        expected = {'test1', 'test2', 'test3'}
        self.assertEqual(expected, tests)
        self.assertEqual(expected_all_tests, all_tests)
159
if __name__ == '__main__':
    # Allow this file to be executed directly to run this module's tests.
    unittest.main()
162