1#!/usr/bin/env python3
2#
3# Copyright 2018, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Unittests for atest_tf_test_runner."""
18
19# pylint: disable=line-too-long
20
21import os
22import sys
23import tempfile
24import unittest
25import json
26
27from io import StringIO
28from unittest import mock
29
30import constants
31import unittest_constants as uc
32import unittest_utils
33
34from test_finders import test_info
35from test_runners import event_handler
36from test_runners import atest_tf_test_runner as atf_tr
37
38#pylint: disable=protected-access
39#pylint: disable=invalid-name
# Short alias for the runner class exercised throughout this module.
_TF_RUNNER = atf_tr.AtestTradefedTestRunner

# Results directory handed to the runner under test.
TEST_INFO_DIR = '/tmp/atest_run_1510085893_pi_Nbi'
METRICS_DIR = '{}/baseline-metrics'.format(TEST_INFO_DIR)
METRICS_DIR_ARG = '--metrics-folder {} '.format(METRICS_DIR)
# TODO(147567606): Replace {serial} with {extra_args} for general extra
# arguments testing.
RUN_CMD_ARGS = '{metrics}--log-level WARN{serial}'
LOG_ARGS = _TF_RUNNER._LOG_ARGS.format(
    log_path=os.path.join(TEST_INFO_DIR, atf_tr.LOG_FOLDER_NAME))
# Expected run command. The {metrics}, {serial} and {tf_customize_template}
# placeholders survive this format() call and are filled in by each test.
RUN_CMD = _TF_RUNNER._RUN_CMD.format(
    exe=_TF_RUNNER.EXECUTABLE,
    template=_TF_RUNNER._TF_TEMPLATE,
    tf_customize_template='{tf_customize_template}',
    args=RUN_CMD_ARGS,
    log_args=LOG_ARGS)


def _tf_info(name, build_targets, *args, **kwargs):
    """Return a TestInfo named `name` that is bound to the Tradefed runner.

    Extra positional/keyword arguments are forwarded unchanged to
    test_info.TestInfo after the build targets.
    """
    return test_info.TestInfo(name, _TF_RUNNER.NAME, build_targets,
                              *args, **kwargs)


def _module_data(filters, config=uc.CONFIG_FILE, module_args=None):
    """Return a TestInfo data dict with a config path and test filters.

    The module-arg entry is only added when `module_args` is given (an
    empty list still counts as given).
    """
    data = {constants.TI_REL_CONFIG: config,
            constants.TI_FILTER: filters}
    if module_args is not None:
        data[constants.TI_MODULE_ARG] = module_args
    return data


# Test filters.
FULL_CLASS2_NAME = 'android.jank.cts.ui.SomeOtherClass'
CLASS2_FILTER = test_info.TestFilter(FULL_CLASS2_NAME, frozenset())
METHOD2_FILTER = test_info.TestFilter(uc.FULL_CLASS_NAME,
                                      frozenset([uc.METHOD2_NAME]))
MODULE_ARG1 = [(constants.TF_INCLUDE_FILTER_OPTION, 'A'),
               (constants.TF_INCLUDE_FILTER_OPTION, 'B')]
MODULE_ARG2 = []
CLASS2_METHOD_FILTER = test_info.TestFilter(
    FULL_CLASS2_NAME, frozenset([uc.METHOD_NAME, uc.METHOD2_NAME]))

# Per-module / per-class test infos.
MODULE2_INFO = _tf_info(
    uc.MODULE2_NAME, set(),
    data=_module_data(frozenset(), config=uc.CONFIG2_FILE))
CLASS1_BUILD_TARGETS = {'class_1_build_target'}
CLASS1_INFO = _tf_info(
    uc.MODULE_NAME, CLASS1_BUILD_TARGETS,
    data=_module_data(frozenset([uc.CLASS_FILTER])))
CLASS2_BUILD_TARGETS = {'class_2_build_target'}
CLASS2_INFO = _tf_info(
    uc.MODULE_NAME, CLASS2_BUILD_TARGETS,
    data=_module_data(frozenset([CLASS2_FILTER])))
CLASS3_BUILD_TARGETS = {'class_3_build_target'}
CLASS3_INFO = _tf_info(
    uc.MODULE_NAME, CLASS3_BUILD_TARGETS,
    data=_module_data(frozenset(), module_args=MODULE_ARG1))
CLASS4_BUILD_TARGETS = {'class_4_build_target'}
CLASS4_INFO = _tf_info(
    uc.MODULE_NAME, CLASS4_BUILD_TARGETS,
    data=_module_data(frozenset(), module_args=MODULE_ARG2))

# Expected results of flattening the infos above.
CLASS1_CLASS2_MODULE_INFO = _tf_info(
    uc.MODULE_NAME,
    uc.MODULE_BUILD_TARGETS | CLASS1_BUILD_TARGETS | CLASS2_BUILD_TARGETS,
    uc.MODULE_DATA)
FLAT_CLASS_INFO = _tf_info(
    uc.MODULE_NAME,
    CLASS1_BUILD_TARGETS | CLASS2_BUILD_TARGETS,
    data=_module_data(frozenset([uc.CLASS_FILTER, CLASS2_FILTER])))
FLAT2_CLASS_INFO = _tf_info(
    uc.MODULE_NAME,
    CLASS3_BUILD_TARGETS | CLASS4_BUILD_TARGETS,
    data=_module_data(frozenset(), module_args=MODULE_ARG1 + MODULE_ARG2))
GTF_INT_CONFIG = os.path.join(uc.GTF_INT_DIR, uc.GTF_INT_NAME + '.xml')
CLASS2_METHOD_INFO = _tf_info(
    uc.MODULE_NAME, set(),
    data=_module_data(frozenset([test_info.TestFilter(
        FULL_CLASS2_NAME, frozenset([uc.METHOD_NAME, uc.METHOD2_NAME]))])))
METHOD_AND_CLASS2_METHOD = _tf_info(
    uc.MODULE_NAME, uc.MODULE_BUILD_TARGETS,
    data=_module_data(frozenset([uc.METHOD_FILTER, CLASS2_METHOD_FILTER])))
METHOD_METHOD2_AND_CLASS2_METHOD = _tf_info(
    uc.MODULE_NAME, uc.MODULE_BUILD_TARGETS,
    data=_module_data(frozenset([uc.FLAT_METHOD_FILTER,
                                 CLASS2_METHOD_FILTER])))
METHOD2_INFO = _tf_info(
    uc.MODULE_NAME, set(),
    data=_module_data(frozenset([METHOD2_FILTER])))

# Infos distinguished only by their test_finder.
INT_INFO = _tf_info(uc.INT_NAME, set(), test_finder='INTEGRATION')

MOD_INFO = _tf_info(uc.MODULE_NAME, set(), test_finder='MODULE')

MOD_INFO_NO_TEST_FINDER = _tf_info(uc.MODULE_NAME, set())

# A well-formed sequence of (event name, event payload) pairs as emitted on
# the TF socket. Key order matters: tests json.dumps these payloads and
# split the resulting text at fixed positions.
EVENTS_NORMAL = [
    ('TEST_MODULE_STARTED',
     {'moduleContextFileName': 'serial-util1146216{974}2772610436.ser',
      'moduleName': 'someTestModule'}),
    ('TEST_RUN_STARTED',
     {'testCount': 2}),
    ('TEST_STARTED',
     {'start_time': 52,
      'className': 'someClassName',
      'testName': 'someTestName'}),
    ('TEST_ENDED',
     {'end_time': 1048,
      'className': 'someClassName',
      'testName': 'someTestName'}),
    ('TEST_STARTED',
     {'start_time': 48,
      'className': 'someClassName2',
      'testName': 'someTestName2'}),
    ('TEST_FAILED',
     {'className': 'someClassName2',
      'testName': 'someTestName2',
      'trace': 'someTrace'}),
    ('TEST_ENDED',
     {'end_time': 9876450,
      'className': 'someClassName2',
      'testName': 'someTestName2'}),
    ('TEST_RUN_ENDED',
     {}),
    ('TEST_MODULE_ENDED',
     {'foo': 'bar'}),
]
174
class AtestTradefedTestRunnerUnittests(unittest.TestCase):
    """Unit tests for atest_tf_test_runner.py"""

    @mock.patch.dict('os.environ', {constants.ANDROID_BUILD_TOP: '/'})
    def setUp(self):
        """Create a runner that writes its results under TEST_INFO_DIR."""
        self.tr = atf_tr.AtestTradefedTestRunner(results_dir=TEST_INFO_DIR)

    def tearDown(self):
        """Stop every patch that was started with patch(...).start()."""
        mock.patch.stopall()

    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       '_start_socket_server')
    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       'run')
    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       '_create_test_args', return_value=['some_args'])
    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       'generate_run_commands', return_value='some_cmd')
    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       '_process_connection', return_value=None)
    @mock.patch('select.select')
    @mock.patch('os.killpg', return_value=None)
    @mock.patch('os.getpgid', return_value=None)
    @mock.patch('signal.signal', return_value=None)
    def test_run_tests_pretty(self, _signal, _pgid, _killpg, mock_select,
                              _process, _run_cmd, _test_args,
                              mock_run, mock_start_socket_server):
        """Test run_tests_pretty with and without an early TF exit."""
        mock_subproc = mock.Mock()
        mock_run.return_value = mock_subproc
        mock_subproc.returncode = 0
        mock_subproc.poll.side_effect = [True, True, None]
        mock_server = mock.Mock()
        mock_server.getsockname.return_value = ('', '')
        mock_start_socket_server.return_value = mock_server
        mock_reporter = mock.Mock()

        # Test no early TF exit
        mock_conn = mock.Mock()
        mock_server.accept.return_value = (mock_conn, 'some_addr')
        mock_server.close.return_value = True
        mock_select.side_effect = [([mock_server], None, None),
                                   ([mock_conn], None, None)]
        self.tr.run_tests_pretty([MODULE2_INFO], {}, mock_reporter)

        # Test early TF exit: TradeFedExitError is raised and the TF log
        # content is expected on stdout.
        with tempfile.NamedTemporaryFile() as tmp_file:
            with open(tmp_file.name, 'w') as f:
                f.write('tf msg')
            self.tr.test_log_file = tmp_file
            mock_select.side_effect = [([], None, None)]
            mock_subproc.poll.side_effect = None
            # Patch (rather than assign) sys.stdout so the previous stream is
            # restored even if the assertion fails, and so an outer capturing
            # test runner is not clobbered by sys.__stdout__.
            with mock.patch.object(sys, 'stdout',
                                   new_callable=StringIO) as capture_output:
                self.assertRaises(atf_tr.TradeFedExitError,
                                  self.tr.run_tests_pretty,
                                  [MODULE2_INFO], {}, mock_reporter)
        self.assertIn('tf msg', capture_output.getvalue())

    def _check_start_monitor_two_connections(self, mock_select, mock_process,
                                             poll_results):
        """Drive _start_monitor through two connections and check cleanup.

        The server accepts two connections. Each connection yields one data
        chunk and then reports closure (_process_connection returns False).

        Args:
            mock_select: Mock standing in for select.select.
            mock_process: Mock standing in for _process_connection.
            poll_results: Successive return values of the subprocess poll().
        """
        mock_server = mock.Mock()
        mock_subproc = mock.Mock()
        mock_reporter = mock.Mock()
        mock_conn1 = mock.Mock()
        mock_conn2 = mock.Mock()
        mock_server.accept.side_effect = [(mock_conn1, 'addr 1'),
                                          (mock_conn2, 'addr 2')]
        mock_select.side_effect = [([mock_server], None, None),
                                   ([mock_server], None, None),
                                   ([mock_conn1], None, None),
                                   ([mock_conn2], None, None),
                                   ([mock_conn1], None, None),
                                   ([mock_conn2], None, None)]
        mock_process.side_effect = ['abc', 'def', False, False]
        mock_subproc.poll.side_effect = poll_results
        self.tr._start_monitor(mock_server, mock_subproc, mock_reporter)
        self.assertEqual(mock_process.call_count, 4)
        mock_server.assert_has_calls([mock.call.accept(), mock.call.close()])
        mock_conn1.assert_has_calls([mock.call.close()])
        mock_conn2.assert_has_calls([mock.call.close()])

    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_process_connection')
    @mock.patch('select.select')
    def test_start_monitor_2_connection(self, mock_select, mock_process):
        """Test _start_monitor with two connections while TF keeps running."""
        self._check_start_monitor_two_connections(
            mock_select, mock_process, [None, None, None, None, None, True])

    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_process_connection')
    @mock.patch('select.select')
    def test_start_monitor_tf_exit_before_2nd_connection(self,
                                                         mock_select,
                                                         mock_process):
        """Test _start_monitor still drains sockets after an early TF exit."""
        # TF exit early but have not processed data in socket buffer.
        self._check_start_monitor_two_connections(
            mock_select, mock_process, [None, None, True, True, True, True])

    def test_start_socket_server(self):
        """Test start_socket_server method."""
        server = self.tr._start_socket_server()
        # Close the socket even if one of the assertions below fails.
        self.addCleanup(server.close)
        host, port = server.getsockname()
        self.assertEqual(host, atf_tr.SOCKET_HOST)
        self.assertLessEqual(port, 65535)
        self.assertGreaterEqual(port, 1024)

    @mock.patch('os.path.exists')
    @mock.patch.dict('os.environ', {'APE_API_KEY': '/tmp/123.json'})
    def test_try_set_gts_authentication_key_is_set_by_user(self, mock_exist):
        """Test _try_set_gts_authentication_key when APE_API_KEY is set."""
        # A user-provided key must be used as-is, without probing files.
        self.tr._try_set_gts_authentication_key()
        mock_exist.assert_not_called()

    @mock.patch('os.path.join', return_value='/tmp/file_not_exist.json')
    def test_try_set_gts_authentication_key_not_set(self, _):
        """Test _try_set_gts_authentication_key when APE_API_KEY is unset."""
        # patch.dict restores os.environ on exit, so removing the key here
        # does not leak into other tests or the calling process.
        with mock.patch.dict('os.environ'):
            os.environ.pop('APE_API_KEY', None)
            self.tr._try_set_gts_authentication_key()
            self.assertIsNone(os.environ.get('APE_API_KEY'))

    @mock.patch.object(event_handler.EventHandler, 'process_event')
    def test_process_connection(self, mock_pe):
        """Test _process_connection with one complete event per recv."""
        mock_socket = mock.Mock()
        for name, data in EVENTS_NORMAL:
            data_map = {mock_socket: ''}
            socket_data = '%s %s' % (name, json.dumps(data))
            mock_socket.recv.return_value = socket_data
            self.tr._process_connection(data_map, mock_socket, mock_pe)

        calls = [mock.call.process_event(name, data) for name, data in EVENTS_NORMAL]
        mock_pe.assert_has_calls(calls)
        # An empty recv signals a closed connection.
        mock_socket.recv.return_value = ''
        self.assertFalse(self.tr._process_connection(data_map, mock_socket, mock_pe))

    @mock.patch.object(event_handler.EventHandler, 'process_event')
    def test_process_connection_multiple_lines_in_single_recv(self, mock_pe):
        """Test _process_connection when recv reads multiple lines in one go."""
        mock_socket = mock.Mock()
        squashed_events = '\n'.join(['%s %s' % (name, json.dumps(data))
                                     for name, data in EVENTS_NORMAL])
        socket_data = [squashed_events, '']
        mock_socket.recv.side_effect = socket_data
        data_map = {mock_socket: ''}
        self.tr._process_connection(data_map, mock_socket, mock_pe)
        calls = [mock.call.process_event(name, data) for name, data in EVENTS_NORMAL]
        mock_pe.assert_has_calls(calls)

    @mock.patch.object(event_handler.EventHandler, 'process_event')
    def test_process_connection_with_buffering(self, mock_pe):
        """Test _process_connection when events overflow socket buffer size."""
        mock_socket = mock.Mock()
        module_events = [EVENTS_NORMAL[0], EVENTS_NORMAL[-1]]
        socket_events = ['%s %s' % (name, json.dumps(data))
                         for name, data in module_events]
        # test try-block code by breaking apart first event after first }
        index = socket_events[0].index('}') + 1
        socket_data = [socket_events[0][:index], socket_events[0][index:]]
        # test non-try block buffering with second event
        socket_data.extend([socket_events[1][:-4], socket_events[1][-4:], ''])
        mock_socket.recv.side_effect = socket_data
        data_map = {mock_socket: ''}
        # One call per non-empty chunk queued above.
        for _ in range(4):
            self.tr._process_connection(data_map, mock_socket, mock_pe)
        calls = [mock.call.process_event(name, data) for name, data in module_events]
        mock_pe.assert_has_calls(calls)

    @mock.patch.object(event_handler.EventHandler, 'process_event')
    def test_process_connection_with_not_completed_event_data(self, mock_pe):
        """Test _process_connection when event data starts with a newline."""
        mock_socket = mock.Mock()
        mock_socket.recv.return_value = ('\n%s %s'
                                         % (EVENTS_NORMAL[0][0],
                                            json.dumps(EVENTS_NORMAL[0][1])))
        data_map = {mock_socket: ''}
        self.tr._process_connection(data_map, mock_socket, mock_pe)
        calls = [mock.call.process_event(EVENTS_NORMAL[0][0],
                                         EVENTS_NORMAL[0][1])]
        mock_pe.assert_has_calls(calls)

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_generate_metrics_folder')
    @mock.patch('atest_utils.get_result_server_args')
    def test_generate_run_commands_without_serial_env(self, mock_resultargs, mock_metrics, _):
        """Test generate_run_commands when no serial env var is set."""
        # Basic Run Cmd
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {}),
            [RUN_CMD.format(metrics='',
                            serial='',
                            tf_customize_template='')])
        mock_metrics.return_value = METRICS_DIR
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {}),
            [RUN_CMD.format(metrics=METRICS_DIR_ARG,
                            serial='',
                            tf_customize_template='')])
        # Run cmd with result server args.
        result_arg = '--result_arg'
        mock_resultargs.return_value = [result_arg]
        mock_metrics.return_value = ''
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {}),
            [RUN_CMD.format(metrics='',
                            serial='',
                            tf_customize_template='') + ' ' + result_arg])

    @mock.patch('os.environ.get')
    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_generate_metrics_folder')
    @mock.patch('atest_utils.get_result_server_args')
    def test_generate_run_commands_with_serial_env(self, mock_resultargs, mock_metrics, mock_env):
        """Test generate_run_commands when a serial env var is set."""
        # Basic Run Cmd
        env_device_serial = 'env-device-0'
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        mock_env.return_value = env_device_serial
        env_serial_arg = ' --serial %s' % env_device_serial
        # Serial env be set and without --serial arg.
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {}),
            [RUN_CMD.format(metrics='',
                            serial=env_serial_arg,
                            tf_customize_template='')])
        # Serial env be set but with --serial arg.
        arg_device_serial = 'arg-device-0'
        arg_serial_arg = ' --serial %s' % arg_device_serial
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {constants.SERIAL: arg_device_serial}),
            [RUN_CMD.format(metrics='',
                            serial=arg_serial_arg,
                            tf_customize_template='')])
        # Serial env be set but with -n arg
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {constants.HOST: True}),
            [RUN_CMD.format(metrics='',
                            serial='',
                            tf_customize_template='') +
             ' -n --prioritize-host-config --skip-host-arch-check'])

    def test_flatten_test_filters(self):
        """Test _flatten_test_filters method."""
        # No Flattening
        filters = self.tr._flatten_test_filters({uc.CLASS_FILTER})
        unittest_utils.assert_strict_equal(self, frozenset([uc.CLASS_FILTER]),
                                           filters)
        filters = self.tr._flatten_test_filters({CLASS2_FILTER})
        unittest_utils.assert_strict_equal(
            self, frozenset([CLASS2_FILTER]), filters)
        filters = self.tr._flatten_test_filters({uc.METHOD_FILTER})
        unittest_utils.assert_strict_equal(
            self, frozenset([uc.METHOD_FILTER]), filters)
        filters = self.tr._flatten_test_filters({uc.METHOD_FILTER,
                                                 CLASS2_METHOD_FILTER})
        unittest_utils.assert_strict_equal(
            self, frozenset([uc.METHOD_FILTER, CLASS2_METHOD_FILTER]), filters)
        # Flattening
        filters = self.tr._flatten_test_filters({uc.METHOD_FILTER,
                                                 METHOD2_FILTER})
        unittest_utils.assert_strict_equal(
            self, filters, frozenset([uc.FLAT_METHOD_FILTER]))
        filters = self.tr._flatten_test_filters({uc.METHOD_FILTER,
                                                 METHOD2_FILTER,
                                                 CLASS2_METHOD_FILTER,})
        unittest_utils.assert_strict_equal(
            self, filters, frozenset([uc.FLAT_METHOD_FILTER,
                                      CLASS2_METHOD_FILTER]))

    def test_flatten_test_infos(self):
        """Test _flatten_test_infos method."""
        # No Flattening
        test_infos = self.tr._flatten_test_infos({uc.MODULE_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.MODULE_INFO})

        test_infos = self.tr._flatten_test_infos([uc.MODULE_INFO, MODULE2_INFO])
        unittest_utils.assert_equal_testinfo_sets(
            self, test_infos, {uc.MODULE_INFO, MODULE2_INFO})

        test_infos = self.tr._flatten_test_infos({CLASS1_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {CLASS1_INFO})

        test_infos = self.tr._flatten_test_infos({uc.INT_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.INT_INFO})

        test_infos = self.tr._flatten_test_infos({uc.METHOD_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.METHOD_INFO})

        # Flattening
        test_infos = self.tr._flatten_test_infos({CLASS1_INFO, CLASS2_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {FLAT_CLASS_INFO})

        test_infos = self.tr._flatten_test_infos({CLASS1_INFO, uc.INT_INFO,
                                                  CLASS2_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.INT_INFO,
                                                   FLAT_CLASS_INFO})

        test_infos = self.tr._flatten_test_infos({CLASS1_INFO, uc.MODULE_INFO,
                                                  CLASS2_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {CLASS1_CLASS2_MODULE_INFO})

        test_infos = self.tr._flatten_test_infos({MODULE2_INFO, uc.INT_INFO,
                                                  CLASS1_INFO, CLASS2_INFO,
                                                  uc.GTF_INT_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.INT_INFO, uc.GTF_INT_INFO,
                                                   FLAT_CLASS_INFO,
                                                   MODULE2_INFO})

        test_infos = self.tr._flatten_test_infos({uc.METHOD_INFO,
                                                  CLASS2_METHOD_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {METHOD_AND_CLASS2_METHOD})

        test_infos = self.tr._flatten_test_infos({uc.METHOD_INFO, METHOD2_INFO,
                                                  CLASS2_METHOD_INFO})
        unittest_utils.assert_equal_testinfo_sets(
            self, test_infos, {METHOD_METHOD2_AND_CLASS2_METHOD})
        test_infos = self.tr._flatten_test_infos({uc.METHOD_INFO, METHOD2_INFO,
                                                  CLASS2_METHOD_INFO,
                                                  MODULE2_INFO,
                                                  uc.INT_INFO})
        unittest_utils.assert_equal_testinfo_sets(
            self, test_infos, {uc.INT_INFO, MODULE2_INFO,
                               METHOD_METHOD2_AND_CLASS2_METHOD})

        test_infos = self.tr._flatten_test_infos({CLASS3_INFO, CLASS4_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {FLAT2_CLASS_INFO})

    def test_create_test_args(self):
        """Test _create_test_args method."""
        # TF_SKIP_LOADING_CONFIG_JAR is expected only when every test info
        # has its test_finder set and none of them is INTEGRATION.
        args = self.tr._create_test_args([MOD_INFO])
        self.assertIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

        args = self.tr._create_test_args([INT_INFO])
        self.assertNotIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

        args = self.tr._create_test_args([MOD_INFO_NO_TEST_FINDER])
        self.assertNotIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

        args = self.tr._create_test_args([MOD_INFO_NO_TEST_FINDER, INT_INFO])
        self.assertNotIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

        args = self.tr._create_test_args([MOD_INFO_NO_TEST_FINDER, INT_INFO, MOD_INFO])
        self.assertNotIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_generate_metrics_folder')
    @mock.patch('atest_utils.get_result_server_args')
    def test_generate_run_commands_collect_tests_only(self,
                                                      mock_resultargs,
                                                      mock_metrics, _):
        """Test generate_run_commands with and without collect-tests-only."""
        # Testing without collect-tests-only
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        extra_args = {}
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], extra_args),
            [RUN_CMD.format(
                metrics='',
                serial='',
                tf_customize_template='')])
        # Testing with collect-tests-only. The {serial} slot is reused here
        # because it is where RUN_CMD places trailing extra arguments.
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        extra_args = {constants.COLLECT_TESTS_ONLY: True}
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], extra_args),
            [RUN_CMD.format(
                metrics='',
                serial=' --collect-tests-only',
                tf_customize_template='')])

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_generate_metrics_folder')
    @mock.patch('atest_utils.get_result_server_args')
    def test_generate_run_commands_with_tf_template(self, mock_resultargs, mock_metrics, _):
        """Test generate_run_commands with --template:map arguments."""
        tf_template_key1 = 'tf_tmplate_key1'
        tf_template_val1 = 'tf_tmplate_val1'
        tf_template_key2 = 'tf_tmplate_key2'
        tf_template_val2 = 'tf_tmplate_val2'
        # Testing with only one tradefed template command
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        extra_args = {constants.TF_TEMPLATE:
                          ['{}={}'.format(tf_template_key1,
                                          tf_template_val1)]}
        # Format the template fragment on its own, so the key/value are never
        # substituted into any other braces that may exist in RUN_CMD.
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], extra_args),
            [RUN_CMD.format(
                metrics='',
                serial='',
                tf_customize_template='--template:map {}={}'.format(
                    tf_template_key1,
                    tf_template_val1))])
        # Testing with two tradefed template commands
        extra_args = {constants.TF_TEMPLATE:
                          ['{}={}'.format(tf_template_key1,
                                          tf_template_val1),
                           '{}={}'.format(tf_template_key2,
                                          tf_template_val2)]}
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], extra_args),
            [RUN_CMD.format(
                metrics='',
                serial='',
                tf_customize_template=(
                    '--template:map {}={} --template:map {}={}'.format(
                        tf_template_key1,
                        tf_template_val1,
                        tf_template_key2,
                        tf_template_val2)))])
642
if __name__ == '__main__':
    # Entry point when this file is executed directly as a script.
    unittest.main(module=__name__)
645