#!/usr/bin/python
# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import mock
import unittest

import common
from autotest_lib.client.common_lib import error
from autotest_lib.server.hosts import base_label_unittest, factory


class MockHost(object):
    """Mock host object with no side effects."""

    def __init__(self, hostname, **args):
        # Record every constructor argument so tests can inspect exactly
        # what the factory passed through.
        self._init_args = args
        self._init_args['hostname'] = hostname


    def job_start(self):
        """Only method called by factory."""
        pass


class MockConnectivity(object):
    """Mock connectivity object with no side effects."""

    def __init__(self, hostname, **args):
        pass


    def close(self):
        """Only method called by factory."""
        pass


def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: value stored on the generated class as `_host_cls_name`.
    @param check_host: value returned by the generated class's static
            check_host(host, timeout=None) method.

    @return: a new subclass of MockHost.
    """
    return type('mock_host_%s' % name, (MockHost,), {
        '_host_cls_name': name,
        'check_host': staticmethod(lambda host, timeout=None: check_host)
    })


def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: value stored on the generated class as `_conn_cls_name`.

    @return: a new subclass of MockConnectivity.
    """
    return type('mock_conn_%s' % name, (MockConnectivity,),
                {'_conn_cls_name': name})


def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; defaults to a fresh empty list.
    @param attributes: dict of host attributes; defaults to a fresh empty
            dict.

    @return: machine dict with a mocked AFE Host object and a sentinel
            host_info_store.
    """
    # Build fresh containers per call. A shared mutable default would carry
    # any mutation made through the mock AFE host over into later tests.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': mock.sentinel.dummy}


def _patch_ssh_globals():
    """Return a patcher that sets the ssh_* globals on the factory module.

    create=True lets mock add the attributes when factory does not already
    define them. On exit, mock restores the previous values, or deletes the
    attributes it created, even if the test body raised.

    @return: a mock.patch.multiple patcher, usable as a context manager.
    """
    return mock.patch.multiple(factory, create=True,
                               ssh_user='foo',
                               ssh_pass='bar',
                               ssh_port=1,
                               ssh_verbosity_flag='baz',
                               ssh_options='zip')


class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.
        """
        self._orig_ssh_engine = factory.SSH_ENGINE
        self._orig_types = factory.host_types
        self._orig_dict = factory.OS_HOST_DICT
        self._orig_cros_host = factory.cros_host.CrosHost
        self._orig_local_host = factory.local_host.LocalHost
        self._orig_ssh_host = factory.ssh_host.SSHHost

        # Empty detection tables that individual tests populate as needed.
        self.host_types = factory.host_types = []
        self.os_host_dict = factory.OS_HOST_DICT = {}
        factory.cros_host.CrosHost = _gen_mock_host('cros_host')
        factory.local_host.LocalHost = _gen_mock_conn('local')
        factory.ssh_host.SSHHost = _gen_mock_conn('ssh')


    def tearDown(self):
        """Clean up mocks."""
        factory.SSH_ENGINE = self._orig_ssh_engine
        factory.host_types = self._orig_types
        factory.OS_HOST_DICT = self._orig_dict
        factory.cros_host.CrosHost = self._orig_cros_host
        factory.local_host.LocalHost = self._orig_local_host
        factory.ssh_host.SSHHost = self._orig_ssh_host


    def test_use_specified(self):
        """Confirm that the specified host and connectivity classes are used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
                _gen_mock_conn('specified')
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')
        self.assertEqual(host_obj._conn_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        factory.SSH_ENGINE = 'raw_ssh'
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_choose_connectivity_unsupported(self):
        """Confirm exception when configured for unsupported ssh engine.
        """
        factory.SSH_ENGINE = 'unsupported'
        machine = _gen_machine_dict(hostname='somehost')
        with self.assertRaises(error.AutoservError):
            factory.create_host(machine)


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through
        to the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        machine = _gen_machine_dict()
        with _patch_ssh_globals():
            host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'foo')
        self.assertEqual(host_obj._init_args['password'], 'bar')
        self.assertEqual(host_obj._init_args['port'], 1)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
        self.assertEqual(host_obj._init_args['ssh_options'], 'zip')


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')


class CreateTestbedUnittests(unittest.TestCase):
    """Tests for create_testbed function."""

    def setUp(self):
        """Mock out TestBed class to eliminate side effects.
        """
        self._orig_testbed = factory.testbed.TestBed
        factory.testbed.TestBed = _gen_mock_host('testbed')


    def tearDown(self):
        """Clean up mock.
        """
        factory.testbed.TestBed = self._orig_testbed


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through
        to the testbed object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        testbed_obj = factory.create_testbed(machine, foo='bar')
        self.assertEqual(testbed_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', testbed_obj._init_args)
        self.assertIn('host_info_store', testbed_obj._init_args)
        self.assertEqual(testbed_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        machine = _gen_machine_dict()
        with _patch_ssh_globals():
            testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'foo')
        self.assertEqual(testbed_obj._init_args['password'], 'bar')
        self.assertEqual(testbed_obj._init_args['port'], 1)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'], 'baz')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'zip')


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'somebody')
        self.assertEqual(testbed_obj._init_args['port'], 100)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'options')


if __name__ == '__main__':
    unittest.main()