1#!/usr/bin/env python
2#
3# Copyright 2018, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Unittests for atest_tf_test_runner."""
18
19import os
20import sys
21import tempfile
22import unittest
23import json
24import mock
25
26# pylint: disable=import-error
27import constants
28import unittest_constants as uc
29import unittest_utils
30import atest_tf_test_runner as atf_tr
31import event_handler
32from test_finders import test_info
33
34if sys.version_info[0] == 2:
35    from StringIO import StringIO
36else:
37    from io import StringIO
38
39#pylint: disable=protected-access
40#pylint: disable=invalid-name
# Short alias for the class under test, used only to build the constants below.
_TF_RUNNER = atf_tr.AtestTradefedTestRunner

TEST_INFO_DIR = '/tmp/atest_run_1510085893_pi_Nbi'
# Expected metrics folder and the matching command-line fragment used when
# _generate_metrics_folder is mocked to return METRICS_DIR.
METRICS_DIR = '%s/baseline-metrics' % TEST_INFO_DIR
METRICS_DIR_ARG = '--metrics-folder %s ' % METRICS_DIR
# TODO(147567606): Replace {serial} with {extra_args} for general extra
# arguments testing.
RUN_CMD_ARGS = '{metrics}--log-level WARN{serial}'
LOG_ARGS = _TF_RUNNER._LOG_ARGS.format(
    log_path=os.path.join(TEST_INFO_DIR, atf_tr.LOG_FOLDER_NAME))
# Pre-rendered run command. {metrics}, {serial} (via RUN_CMD_ARGS) and
# {tf_customize_template} survive this format() call as placeholders so each
# test can fill them in with RUN_CMD.format(...).
RUN_CMD = _TF_RUNNER._RUN_CMD.format(
    log_args=LOG_ARGS,
    args=RUN_CMD_ARGS,
    tf_customize_template='{tf_customize_template}',
    template=_TF_RUNNER._TF_TEMPLATE,
    exe=_TF_RUNNER.EXECUTABLE)
# --- Test filters -----------------------------------------------------------
FULL_CLASS2_NAME = 'android.jank.cts.ui.SomeOtherClass'
CLASS2_FILTER = test_info.TestFilter(FULL_CLASS2_NAME, frozenset())
METHOD2_FILTER = test_info.TestFilter(uc.FULL_CLASS_NAME,
                                      frozenset([uc.METHOD2_NAME]))
# Module args: a pair of include-filter options, and an empty list.
MODULE_ARG1 = [(constants.TF_INCLUDE_FILTER_OPTION, "A"),
               (constants.TF_INCLUDE_FILTER_OPTION, "B")]
MODULE_ARG2 = []
CLASS2_METHOD_FILTER = test_info.TestFilter(
    FULL_CLASS2_NAME, frozenset([uc.METHOD_NAME, uc.METHOD2_NAME]))

# --- Per-class / per-module TestInfos ---------------------------------------
MODULE2_INFO = test_info.TestInfo(uc.MODULE2_NAME,
                                  atf_tr.AtestTradefedTestRunner.NAME,
                                  set(),
                                  data={constants.TI_REL_CONFIG: uc.CONFIG2_FILE,
                                        constants.TI_FILTER: frozenset()})
CLASS1_BUILD_TARGETS = {'class_1_build_target'}
CLASS1_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    CLASS1_BUILD_TARGETS,
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset([uc.CLASS_FILTER])})
CLASS2_BUILD_TARGETS = {'class_2_build_target'}
CLASS2_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    CLASS2_BUILD_TARGETS,
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset([CLASS2_FILTER])})
CLASS3_BUILD_TARGETS = {'class_3_build_target'}
CLASS3_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    CLASS3_BUILD_TARGETS,
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset(),
          constants.TI_MODULE_ARG: MODULE_ARG1})
CLASS4_BUILD_TARGETS = {'class_4_build_target'}
CLASS4_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    CLASS4_BUILD_TARGETS,
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset(),
          constants.TI_MODULE_ARG: MODULE_ARG2})

# --- Expected results of _flatten_test_infos ---------------------------------
CLASS1_CLASS2_MODULE_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    uc.MODULE_BUILD_TARGETS | CLASS1_BUILD_TARGETS | CLASS2_BUILD_TARGETS,
    uc.MODULE_DATA)
FLAT_CLASS_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    CLASS1_BUILD_TARGETS | CLASS2_BUILD_TARGETS,
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset([uc.CLASS_FILTER, CLASS2_FILTER])})
FLAT2_CLASS_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    CLASS3_BUILD_TARGETS | CLASS4_BUILD_TARGETS,
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset(),
          constants.TI_MODULE_ARG: MODULE_ARG1 + MODULE_ARG2})
GTF_INT_CONFIG = os.path.join(uc.GTF_INT_DIR, uc.GTF_INT_NAME + '.xml')
# Reuse CLASS2_METHOD_FILTER rather than rebuilding an identical TestFilter.
CLASS2_METHOD_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    set(),
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset([CLASS2_METHOD_FILTER])})
METHOD_AND_CLASS2_METHOD = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    uc.MODULE_BUILD_TARGETS,
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset([uc.METHOD_FILTER,
                                          CLASS2_METHOD_FILTER])})
METHOD_METHOD2_AND_CLASS2_METHOD = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    uc.MODULE_BUILD_TARGETS,
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset([uc.FLAT_METHOD_FILTER,
                                          CLASS2_METHOD_FILTER])})
METHOD2_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    set(),
    data={constants.TI_REL_CONFIG: uc.CONFIG_FILE,
          constants.TI_FILTER: frozenset([METHOD2_FILTER])})

# --- TestInfos used by the _create_test_args tests ----------------------------
INT_INFO = test_info.TestInfo(
    uc.INT_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    set(),
    test_finder='INTEGRATION')

MOD_INFO = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    set(),
    test_finder='MODULE')

MOD_INFO_NO_TEST_FINDER = test_info.TestInfo(
    uc.MODULE_NAME,
    atf_tr.AtestTradefedTestRunner.NAME,
    set())
156
# Canned (event_name, event_data) pairs. The tests serialize each one as
# '<name> <json>' and feed it to _process_connection. The sequence covers
# one module with one passing test and one failing test. Key order inside
# each dict matters: the buffering test splits the serialized first event
# right after its first '}'.
EVENTS_NORMAL = [
    ('TEST_MODULE_STARTED',
     {'moduleContextFileName': 'serial-util1146216{974}2772610436.ser',
      'moduleName': 'someTestModule'}),
    ('TEST_RUN_STARTED',
     {'testCount': 2}),
    ('TEST_STARTED',
     {'start_time': 52,
      'className': 'someClassName',
      'testName': 'someTestName'}),
    ('TEST_ENDED',
     {'end_time': 1048,
      'className': 'someClassName',
      'testName': 'someTestName'}),
    ('TEST_STARTED',
     {'start_time': 48,
      'className': 'someClassName2',
      'testName': 'someTestName2'}),
    ('TEST_FAILED',
     {'className': 'someClassName2',
      'testName': 'someTestName2',
      'trace': 'someTrace'}),
    ('TEST_ENDED',
     {'end_time': 9876450,
      'className': 'someClassName2',
      'testName': 'someTestName2'}),
    ('TEST_RUN_ENDED', {}),
    ('TEST_MODULE_ENDED', {'foo': 'bar'}),
]
175
class AtestTradefedTestRunnerUnittests(unittest.TestCase):
    """Unit tests for atest_tf_test_runner.py"""

    def setUp(self):
        self.tr = atf_tr.AtestTradefedTestRunner(results_dir=TEST_INFO_DIR)

    def tearDown(self):
        mock.patch.stopall()

    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       '_start_socket_server')
    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       'run')
    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       '_create_test_args', return_value=['some_args'])
    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       'generate_run_commands', return_value='some_cmd')
    @mock.patch.object(atf_tr.AtestTradefedTestRunner,
                       '_process_connection', return_value=None)
    @mock.patch('select.select')
    @mock.patch('os.killpg', return_value=None)
    @mock.patch('os.getpgid', return_value=None)
    @mock.patch('signal.signal', return_value=None)
    def test_run_tests_pretty(self, _signal, _pgid, _killpg, mock_select,
                              _process, _run_cmd, _test_args,
                              mock_run, mock_start_socket_server):
        """Test run_tests_pretty method, with and without an early TF exit."""
        mock_subproc = mock.Mock()
        mock_run.return_value = mock_subproc
        mock_subproc.returncode = 0
        mock_subproc.poll.side_effect = [True, True, None]
        mock_server = mock.Mock()
        mock_server.getsockname.return_value = ('', '')
        mock_start_socket_server.return_value = mock_server
        mock_reporter = mock.Mock()

        # Test no early TF exit
        mock_conn = mock.Mock()
        mock_server.accept.return_value = (mock_conn, 'some_addr')
        mock_server.close.return_value = True
        mock_select.side_effect = [([mock_server], None, None),
                                   ([mock_conn], None, None)]
        self.tr.run_tests_pretty([MODULE2_INFO], {}, mock_reporter)

        # Test early TF exit
        # Keep the temp file open for the whole test (closing it deletes the
        # file) and make sure it is closed once the test is done.
        tmp_file = tempfile.NamedTemporaryFile()
        self.addCleanup(tmp_file.close)
        with open(tmp_file.name, 'w') as f:
            f.write("tf msg")
        self.tr.test_log_file = tmp_file
        mock_select.side_effect = [([], None, None)]
        mock_subproc.poll.side_effect = None
        capture_output = StringIO()
        # Restore the previously active stdout (which may be a test runner's
        # capture stream) even if the assertion below fails.
        original_stdout = sys.stdout
        sys.stdout = capture_output
        try:
            self.assertRaises(atf_tr.TradeFedExitError,
                              self.tr.run_tests_pretty,
                              [MODULE2_INFO], {}, mock_reporter)
        finally:
            sys.stdout = original_stdout
        self.assertIn('tf msg', capture_output.getvalue())

    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_process_connection')
    @mock.patch('select.select')
    def test_start_monitor_2_connection(self, mock_select, mock_process):
        """Test _start_monitor serving two connections until TF exits."""
        mock_server = mock.Mock()
        mock_subproc = mock.Mock()
        mock_reporter = mock.Mock()
        mock_conn1 = mock.Mock()
        mock_conn2 = mock.Mock()
        mock_server.accept.side_effect = [(mock_conn1, 'addr 1'),
                                          (mock_conn2, 'addr 2')]
        mock_select.side_effect = [([mock_server], None, None),
                                   ([mock_server], None, None),
                                   ([mock_conn1], None, None),
                                   ([mock_conn2], None, None),
                                   ([mock_conn1], None, None),
                                   ([mock_conn2], None, None)]
        mock_process.side_effect = ['abc', 'def', False, False]
        mock_subproc.poll.side_effect = [None, None, None, None,
                                         None, True]
        self.tr._start_monitor(mock_server, mock_subproc, mock_reporter)
        self.assertEqual(mock_process.call_count, 4)
        calls = [mock.call.accept(), mock.call.close()]
        mock_server.assert_has_calls(calls)
        mock_conn1.assert_has_calls([mock.call.close()])
        mock_conn2.assert_has_calls([mock.call.close()])

    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_process_connection')
    @mock.patch('select.select')
    def test_start_monitor_tf_exit_before_2nd_connection(self,
                                                         mock_select,
                                                         mock_process):
        """Test _start_monitor still drains both connections after TF exits."""
        mock_server = mock.Mock()
        mock_subproc = mock.Mock()
        mock_reporter = mock.Mock()
        mock_conn1 = mock.Mock()
        mock_conn2 = mock.Mock()
        mock_server.accept.side_effect = [(mock_conn1, 'addr 1'),
                                          (mock_conn2, 'addr 2')]
        mock_select.side_effect = [([mock_server], None, None),
                                   ([mock_server], None, None),
                                   ([mock_conn1], None, None),
                                   ([mock_conn2], None, None),
                                   ([mock_conn1], None, None),
                                   ([mock_conn2], None, None)]
        mock_process.side_effect = ['abc', 'def', False, False]
        # TF exit early but have not processed data in socket buffer.
        mock_subproc.poll.side_effect = [None, None, True, True,
                                         True, True]
        self.tr._start_monitor(mock_server, mock_subproc, mock_reporter)
        self.assertEqual(mock_process.call_count, 4)
        calls = [mock.call.accept(), mock.call.close()]
        mock_server.assert_has_calls(calls)
        mock_conn1.assert_has_calls([mock.call.close()])
        mock_conn2.assert_has_calls([mock.call.close()])

    def test_start_socket_server(self):
        """Test _start_socket_server binds SOCKET_HOST on a port >= 1024."""
        server = self.tr._start_socket_server()
        # Close the socket even if one of the assertions below fails.
        self.addCleanup(server.close)
        host, port = server.getsockname()
        self.assertEqual(host, atf_tr.SOCKET_HOST)
        self.assertLessEqual(port, 65535)
        self.assertGreaterEqual(port, 1024)

    @mock.patch('os.path.exists')
    @mock.patch.dict('os.environ', {'APE_API_KEY':'/tmp/123.json'})
    def test_try_set_gts_authentication_key_is_set_by_user(self, mock_exist):
        """Test _try_set_gts_authentication_key when APE_API_KEY is set."""
        # Test key is set by user: no key file lookup should happen.
        self.tr._try_set_gts_authentication_key()
        mock_exist.assert_not_called()

    @mock.patch('os.path.join', return_value='/tmp/file_not_exist.json')
    def test_try_set_gts_authentication_key_not_set(self, _):
        """Test _try_set_gts_authentication_key when APE_API_KEY is unset."""
        # patch.dict snapshots os.environ and restores it on exit, so
        # removing APE_API_KEY here does not leak into other tests.
        with mock.patch.dict(os.environ):
            os.environ.pop('APE_API_KEY', None)
            self.tr._try_set_gts_authentication_key()
            self.assertIsNone(os.environ.get('APE_API_KEY'))

    @mock.patch.object(event_handler.EventHandler, 'process_event')
    def test_process_connection(self, mock_pe):
        """Test _process_connection with one event per recv call."""
        mock_socket = mock.Mock()
        for name, data in EVENTS_NORMAL:
            datas = {mock_socket: ''}
            socket_data = '%s %s' % (name, json.dumps(data))
            mock_socket.recv.return_value = socket_data
            self.tr._process_connection(datas, mock_socket, mock_pe)

        calls = [mock.call.process_event(name, data)
                 for name, data in EVENTS_NORMAL]
        mock_pe.assert_has_calls(calls)
        # An empty recv (closed connection) should make it return False.
        mock_socket.recv.return_value = ''
        self.assertFalse(
            self.tr._process_connection(datas, mock_socket, mock_pe))

    @mock.patch.object(event_handler.EventHandler, 'process_event')
    def test_process_connection_multiple_lines_in_single_recv(self, mock_pe):
        """Test _process_connection when recv reads multiple lines in one go."""
        mock_socket = mock.Mock()
        squashed_events = '\n'.join(['%s %s' % (name, json.dumps(data))
                                     for name, data in EVENTS_NORMAL])
        socket_data = [squashed_events, '']
        mock_socket.recv.side_effect = socket_data
        datas = {mock_socket: ''}
        self.tr._process_connection(datas, mock_socket, mock_pe)
        calls = [mock.call.process_event(name, data)
                 for name, data in EVENTS_NORMAL]
        mock_pe.assert_has_calls(calls)

    @mock.patch.object(event_handler.EventHandler, 'process_event')
    def test_process_connection_with_buffering(self, mock_pe):
        """Test _process_connection when events overflow socket buffer size"""
        mock_socket = mock.Mock()
        module_events = [EVENTS_NORMAL[0], EVENTS_NORMAL[-1]]
        socket_events = ['%s %s' % (name, json.dumps(data))
                         for name, data in module_events]
        # test try-block code by breaking apart first event after first }
        index = socket_events[0].index('}') + 1
        socket_data = [socket_events[0][:index], socket_events[0][index:]]
        # test non-try block buffering with second event
        socket_data.extend([socket_events[1][:-4], socket_events[1][-4:], ''])
        mock_socket.recv.side_effect = socket_data
        datas = {mock_socket: ''}
        self.tr._process_connection(datas, mock_socket, mock_pe)
        self.tr._process_connection(datas, mock_socket, mock_pe)
        self.tr._process_connection(datas, mock_socket, mock_pe)
        self.tr._process_connection(datas, mock_socket, mock_pe)
        calls = [mock.call.process_event(name, data)
                 for name, data in module_events]
        mock_pe.assert_has_calls(calls)

    @mock.patch.object(event_handler.EventHandler, 'process_event')
    def test_process_connection_with_not_completed_event_data(self, mock_pe):
        """Test _process_connection when the event has a leading newline."""
        mock_socket = mock.Mock()
        mock_socket.recv.return_value = ('\n%s %s'
                                         %(EVENTS_NORMAL[0][0],
                                           json.dumps(EVENTS_NORMAL[0][1])))
        datas = {mock_socket: ''}
        self.tr._process_connection(datas, mock_socket, mock_pe)
        calls = [mock.call.process_event(EVENTS_NORMAL[0][0],
                                         EVENTS_NORMAL[0][1])]
        mock_pe.assert_has_calls(calls)

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_generate_metrics_folder')
    @mock.patch('atest_utils.get_result_server_args')
    def test_generate_run_commands_without_serial_env(self, mock_resultargs,
                                                      mock_metrics, _):
        """Test generate_run_commands when env lookups return None."""
        # Basic Run Cmd
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {}),
            [RUN_CMD.format(metrics='',
                            serial='',
                            tf_customize_template='')])
        mock_metrics.return_value = METRICS_DIR
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {}),
            [RUN_CMD.format(metrics=METRICS_DIR_ARG,
                            serial='',
                            tf_customize_template='')])
        # Run cmd with result server args.
        result_arg = '--result_arg'
        mock_resultargs.return_value = [result_arg]
        mock_metrics.return_value = ''
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {}),
            [RUN_CMD.format(metrics='',
                            serial='',
                            tf_customize_template='') + ' ' + result_arg])

    @mock.patch('os.environ.get')
    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_generate_metrics_folder')
    @mock.patch('atest_utils.get_result_server_args')
    def test_generate_run_commands_with_serial_env(self, mock_resultargs,
                                                   mock_metrics, mock_env):
        """Test generate_run_commands when a serial is set in the env."""
        # Basic Run Cmd
        env_device_serial = 'env-device-0'
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        mock_env.return_value = env_device_serial
        env_serial_arg = ' --serial %s' % env_device_serial
        # Serial set in env and no --serial arg: env serial is used.
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {}),
            [RUN_CMD.format(metrics='',
                            serial=env_serial_arg,
                            tf_customize_template='')])
        # Serial set in env but --serial arg given: the arg wins.
        arg_device_serial = 'arg-device-0'
        arg_serial_arg = ' --serial %s' % arg_device_serial
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {constants.SERIAL:arg_device_serial}),
            [RUN_CMD.format(metrics='',
                            serial=arg_serial_arg,
                            tf_customize_template='')])
        # Serial set in env but with -n (host) arg: no serial is passed.
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], {constants.HOST: True}),
            [RUN_CMD.format(metrics='',
                            serial='',
                            tf_customize_template='') +
             ' -n --prioritize-host-config --skip-host-arch-check'])

    def test_flatten_test_filters(self):
        """Test _flatten_test_filters method."""
        # No Flattening
        filters = self.tr._flatten_test_filters({uc.CLASS_FILTER})
        unittest_utils.assert_strict_equal(self, frozenset([uc.CLASS_FILTER]),
                                           filters)
        filters = self.tr._flatten_test_filters({CLASS2_FILTER})
        unittest_utils.assert_strict_equal(
            self, frozenset([CLASS2_FILTER]), filters)
        filters = self.tr._flatten_test_filters({uc.METHOD_FILTER})
        unittest_utils.assert_strict_equal(
            self, frozenset([uc.METHOD_FILTER]), filters)
        filters = self.tr._flatten_test_filters({uc.METHOD_FILTER,
                                                 CLASS2_METHOD_FILTER})
        unittest_utils.assert_strict_equal(
            self, frozenset([uc.METHOD_FILTER, CLASS2_METHOD_FILTER]), filters)
        # Flattening
        filters = self.tr._flatten_test_filters({uc.METHOD_FILTER,
                                                 METHOD2_FILTER})
        unittest_utils.assert_strict_equal(
            self, filters, frozenset([uc.FLAT_METHOD_FILTER]))
        filters = self.tr._flatten_test_filters({uc.METHOD_FILTER,
                                                 METHOD2_FILTER,
                                                 CLASS2_METHOD_FILTER,})
        unittest_utils.assert_strict_equal(
            self, filters, frozenset([uc.FLAT_METHOD_FILTER,
                                      CLASS2_METHOD_FILTER]))

    def test_flatten_test_infos(self):
        """Test _flatten_test_infos method."""
        # No Flattening
        test_infos = self.tr._flatten_test_infos({uc.MODULE_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.MODULE_INFO})

        test_infos = self.tr._flatten_test_infos([uc.MODULE_INFO, MODULE2_INFO])
        unittest_utils.assert_equal_testinfo_sets(
            self, test_infos, {uc.MODULE_INFO, MODULE2_INFO})

        test_infos = self.tr._flatten_test_infos({CLASS1_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {CLASS1_INFO})

        test_infos = self.tr._flatten_test_infos({uc.INT_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.INT_INFO})

        test_infos = self.tr._flatten_test_infos({uc.METHOD_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.METHOD_INFO})

        # Flattening
        test_infos = self.tr._flatten_test_infos({CLASS1_INFO, CLASS2_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {FLAT_CLASS_INFO})

        test_infos = self.tr._flatten_test_infos({CLASS1_INFO, uc.INT_INFO,
                                                  CLASS2_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.INT_INFO,
                                                   FLAT_CLASS_INFO})

        test_infos = self.tr._flatten_test_infos({CLASS1_INFO, uc.MODULE_INFO,
                                                  CLASS2_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {CLASS1_CLASS2_MODULE_INFO})

        test_infos = self.tr._flatten_test_infos({MODULE2_INFO, uc.INT_INFO,
                                                  CLASS1_INFO, CLASS2_INFO,
                                                  uc.GTF_INT_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {uc.INT_INFO, uc.GTF_INT_INFO,
                                                   FLAT_CLASS_INFO,
                                                   MODULE2_INFO})

        test_infos = self.tr._flatten_test_infos({uc.METHOD_INFO,
                                                  CLASS2_METHOD_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {METHOD_AND_CLASS2_METHOD})

        test_infos = self.tr._flatten_test_infos({uc.METHOD_INFO, METHOD2_INFO,
                                                  CLASS2_METHOD_INFO})
        unittest_utils.assert_equal_testinfo_sets(
            self, test_infos, {METHOD_METHOD2_AND_CLASS2_METHOD})
        test_infos = self.tr._flatten_test_infos({uc.METHOD_INFO, METHOD2_INFO,
                                                  CLASS2_METHOD_INFO,
                                                  MODULE2_INFO,
                                                  uc.INT_INFO})
        unittest_utils.assert_equal_testinfo_sets(
            self, test_infos, {uc.INT_INFO, MODULE2_INFO,
                               METHOD_METHOD2_AND_CLASS2_METHOD})

        test_infos = self.tr._flatten_test_infos({CLASS3_INFO, CLASS4_INFO})
        unittest_utils.assert_equal_testinfo_sets(self, test_infos,
                                                  {FLAT2_CLASS_INFO})

    def test_create_test_args(self):
        """Test _create_test_args method."""
        # These assertions expect '--skip-loading-config-jar' only when every
        # test info has a test finder set and none of them is INTEGRATION.
        args = self.tr._create_test_args([MOD_INFO])
        self.assertIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

        args = self.tr._create_test_args([INT_INFO])
        self.assertNotIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

        args = self.tr._create_test_args([MOD_INFO_NO_TEST_FINDER])
        self.assertNotIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

        args = self.tr._create_test_args([MOD_INFO_NO_TEST_FINDER, INT_INFO])
        self.assertNotIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

        args = self.tr._create_test_args([MOD_INFO_NO_TEST_FINDER, INT_INFO,
                                          MOD_INFO])
        self.assertNotIn(constants.TF_SKIP_LOADING_CONFIG_JAR, args)

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_generate_metrics_folder')
    @mock.patch('atest_utils.get_result_server_args')
    def test_generate_run_commands_with_tf_template(self, mock_resultargs,
                                                    mock_metrics, _):
        """Test generate_run_commands with --template:map arguments."""
        tf_template_key1 = 'tf_tmplate_key1'
        tf_template_val1 = 'tf_tmplate_val1'
        tf_template_key2 = 'tf_tmplate_key2'
        tf_template_val2 = 'tf_tmplate_val2'
        # Testing with only one tradefed template command
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        extra_args = {constants.TF_TEMPLATE:
                          ['{}={}'.format(tf_template_key1,
                                          tf_template_val1)]}
        # Render the template fragment on its own before substituting it, so
        # that any braces elsewhere in RUN_CMD are never re-interpreted.
        expected_template = '--template:map {}={} '.format(tf_template_key1,
                                                          tf_template_val1)
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], extra_args),
            [RUN_CMD.format(
                metrics='',
                serial='',
                tf_customize_template=expected_template)])
        # Testing with two tradefed template commands
        extra_args = {constants.TF_TEMPLATE:
                          ['{}={}'.format(tf_template_key1,
                                          tf_template_val1),
                           '{}={}'.format(tf_template_key2,
                                          tf_template_val2)]}
        expected_template = (
            '--template:map {}={} --template:map {}={} '.format(
                tf_template_key1,
                tf_template_val1,
                tf_template_key2,
                tf_template_val2))
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], extra_args),
            [RUN_CMD.format(
                metrics='',
                serial='',
                tf_customize_template=expected_template)])

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch.object(atf_tr.AtestTradefedTestRunner, '_generate_metrics_folder')
    @mock.patch('atest_utils.get_result_server_args')
    def test_generate_run_commands_collect_tests_only(self,
                                                      mock_resultargs,
                                                      mock_metrics, _):
        """Test generate_run_commands with and without collect-tests-only."""
        mock_resultargs.return_value = []
        mock_metrics.return_value = ''
        # Testing without collect-tests-only
        extra_args = {}
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], extra_args),
            [RUN_CMD.format(
                metrics='',
                serial='',
                tf_customize_template='')])
        # Testing with collect-tests-only. RUN_CMD only has a {serial} slot
        # for extra args (see TODO on RUN_CMD_ARGS), so the flag goes there.
        extra_args = {constants.COLLECT_TESTS_ONLY: True}
        unittest_utils.assert_strict_equal(
            self,
            self.tr.generate_run_commands([], extra_args),
            [RUN_CMD.format(
                metrics='',
                serial=' --collect-tests-only',
                tf_customize_template='')])
641
if __name__ == '__main__':
    # Run every test case in this module when executed as a script;
    # module='__main__' is unittest.main's default, spelled out explicitly.
    unittest.main(module='__main__')
644