#!/usr/bin/python
# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Unit tests for the host factory's create_host and create_testbed."""

import mock
import unittest

import common
from autotest_lib.client.common_lib import error
from autotest_lib.server.hosts import base_label_unittest, factory


# ssh parameters injected as module-level globals of `factory`.
_GLOBAL_SSH_PARAMS = {
        'ssh_user': 'foo',
        'ssh_pass': 'bar',
        'ssh_port': 1,
        'ssh_verbosity_flag': 'baz',
        'ssh_options': 'zip',
}

# Constructor kwargs expected when _GLOBAL_SSH_PARAMS are in effect.
_EXPECTED_GLOBAL_SSH_ARGS = {
        'user': 'foo',
        'password': 'bar',
        'port': 1,
        'ssh_verbosity_flag': 'baz',
        'ssh_options': 'zip',
}

# ssh parameters supplied through AFE host attributes.
_ATTRIBUTE_SSH_PARAMS = {
        'ssh_user': 'somebody',
        'ssh_port': 100,
        'ssh_verbosity_flag': 'verb',
        'ssh_options': 'options',
}

# Constructor kwargs expected when _ATTRIBUTE_SSH_PARAMS are host attributes.
_EXPECTED_ATTRIBUTE_SSH_ARGS = {
        'user': 'somebody',
        'port': 100,
        'ssh_verbosity_flag': 'verb',
        'ssh_options': 'options',
}


class MockHost(object):
    """Mock host object with no side effects."""
    def __init__(self, hostname, **args):
        # Record every constructor argument so tests can inspect what the
        # factory passed through.
        self._init_args = args
        self._init_args['hostname'] = hostname


    def job_start(self):
        """Only method called by factory."""
        pass


class MockConnectivity(object):
    """Mock connectivity object with no side effects."""
    def __init__(self, hostname, **args):
        pass


    def close(self):
        """Only method called by factory."""
        pass


def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: stored as the class attribute _host_cls_name so tests can
            tell which host class the factory selected.
    @param check_host: value returned by the generated class's check_host().

    @return: a new subclass of MockHost.
    """
    return type('mock_host_%s' % name, (MockHost,), {
        '_host_cls_name': name,
        'check_host': staticmethod(lambda host, timeout=None: check_host)
    })


def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: stored as the class attribute _conn_cls_name so tests can
            tell which connectivity class the factory selected.

    @return: a new subclass of MockConnectivity.
    """
    return type('mock_conn_%s' % name, (MockConnectivity,),
                {'_conn_cls_name': name})


def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; a fresh empty list if None.
    @param attributes: dict of host attributes; a fresh empty dict if None.

    @return: machine dict with mocked AFE Host object and a sentinel
            host_info_store.
    """
    # Fresh containers per call: a mutable default would be shared by every
    # machine dict created with the defaults.
    labels = [] if labels is None else labels
    attributes = {} if attributes is None else attributes
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': mock.sentinel.dummy}


def _start_patch(testcase, target, attribute, new):
    """Set target.attribute to `new` until `testcase` finishes.

    The patch is undone through addCleanup, which runs even when setUp
    fails partway, so module state never leaks into later tests.

    @param testcase: the running unittest.TestCase.
    @param target: object (usually a module) to patch.
    @param attribute: name of the attribute to replace.
    @param new: replacement value.

    @return: `new`, so callers can keep a handle on the installed object.
    """
    patcher = mock.patch.object(target, attribute, new)
    patcher.start()
    testcase.addCleanup(patcher.stop)
    return new


def _assert_init_args(testcase, obj, expected):
    """Assert that `obj` was constructed with every kwarg in `expected`.

    @param testcase: the running unittest.TestCase.
    @param obj: a MockHost instance (exposes _init_args).
    @param expected: dict of constructor kwarg name -> expected value.
    """
    actual = dict((key, obj._init_args.get(key)) for key in expected)
    testcase.assertEqual(actual, expected)


def _global_ssh_params_patch():
    """Return a context manager that injects _GLOBAL_SSH_PARAMS into factory.

    create=True lets mock add the globals when factory does not define them.
    On exit they are removed again, or restored if they already existed,
    rather than being unconditionally deleted.
    """
    return mock.patch.multiple(factory, create=True, **_GLOBAL_SSH_PARAMS)


class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.
        """
        # Patched to its own value so that tests may reassign it freely;
        # the original is restored at cleanup.
        _start_patch(self, factory, 'SSH_ENGINE', factory.SSH_ENGINE)
        self.host_types = _start_patch(self, factory, 'host_types', [])
        self.os_host_dict = _start_patch(self, factory, 'OS_HOST_DICT', {})
        _start_patch(self, factory.cros_host, 'CrosHost',
                     _gen_mock_host('cros_host'))
        _start_patch(self, factory.local_host, 'LocalHost',
                     _gen_mock_conn('local'))
        _start_patch(self, factory.ssh_host, 'SSHHost',
                     _gen_mock_conn('ssh'))


    def test_use_specified(self):
        """Confirm that the specified host and connectivity classes are used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
                _gen_mock_conn('specified')
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')
        self.assertEqual(host_obj._conn_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        factory.SSH_ENGINE = 'raw_ssh'
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_choose_connectivity_unsupported(self):
        """Confirm exception when configured for unsupported ssh engine.
        """
        factory.SSH_ENGINE = 'unsupported'
        machine = _gen_machine_dict(hostname='somehost')
        with self.assertRaises(error.AutoservError):
            factory.create_host(machine)


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        machine = _gen_machine_dict()
        with _global_ssh_params_patch():
            host_obj = factory.create_host(machine)
        _assert_init_args(self, host_obj, _EXPECTED_GLOBAL_SSH_ARGS)


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes=dict(_ATTRIBUTE_SSH_PARAMS))
        host_obj = factory.create_host(machine)
        _assert_init_args(self, host_obj, _EXPECTED_ATTRIBUTE_SSH_ARGS)


class CreateTestbedUnittests(unittest.TestCase):
    """Tests for create_testbed function."""

    def setUp(self):
        """Mock out TestBed class to eliminate side effects.
        """
        _start_patch(self, factory.testbed, 'TestBed',
                     _gen_mock_host('testbed'))


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the testbed object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        testbed_obj = factory.create_testbed(machine, foo='bar')
        self.assertEqual(testbed_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', testbed_obj._init_args)
        self.assertIn('host_info_store', testbed_obj._init_args)
        self.assertEqual(testbed_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        machine = _gen_machine_dict()
        with _global_ssh_params_patch():
            testbed_obj = factory.create_testbed(machine)
        _assert_init_args(self, testbed_obj, _EXPECTED_GLOBAL_SSH_ARGS)


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes=dict(_ATTRIBUTE_SSH_PARAMS))
        testbed_obj = factory.create_testbed(machine)
        _assert_init_args(self, testbed_obj, _EXPECTED_ATTRIBUTE_SSH_ARGS)


if __name__ == '__main__':
    unittest.main()