1#!/usr/bin/python
2# Copyright 2016 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6import mock
7import unittest
8
9import common
10from autotest_lib.client.common_lib import error
11from autotest_lib.server.hosts import base_label_unittest, factory
12
13
class MockHost(object):
    """Side-effect-free stand-in for a real host class.

    Construction only records arguments: ``_init_args`` ends up holding every
    keyword argument passed in, plus ``hostname``.
    """

    def __init__(self, hostname, **args):
        args['hostname'] = hostname
        self._init_args = args


    def job_start(self):
        """No-op; this is the only method the factory calls on the host."""
24
25
class MockConnectivity(object):
    """Side-effect-free stand-in for a real connectivity class."""

    def __init__(self, hostname, **args):
        """Accept and ignore the usual constructor arguments."""
        del hostname, args  # deliberately unused


    def close(self):
        """No-op; this is the only method the factory calls on it."""
35
36
def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: label stored on the class as ``_host_cls_name``; also used
            to build the class name ``mock_host_<name>``.
    @param check_host: value returned by the generated class's
            ``check_host(host, timeout=None)``.

    @return: a new MockHost subclass (the class itself, not an instance).
    """
    # check_host is a staticmethod accepting (host, timeout=None); presumably
    # this matches how the factory probes candidate host classes.
    return type('mock_host_%s' % name, (MockHost,), {
        '_host_cls_name': name,
        'check_host': staticmethod(lambda host, timeout=None: check_host)
    })
44
45
def _gen_mock_conn(name):
    """Build a MockConnectivity subclass tagged with ``name``.

    @param name: label stored on the class as ``_conn_cls_name``; also used
            to build the class name ``mock_conn_<name>``.

    @return: the newly created class (not an instance).
    """
    class_attrs = {'_conn_cls_name': name}
    class_name = 'mock_conn_%s' % name
    return type(class_name, (MockConnectivity,), class_attrs)
51
52
def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; None (the default) means no labels.
    @param attributes: dict of host attributes; None (the default) means no
            attributes.

    @return: machine dict with a mocked AFE Host object and a mock sentinel
            standing in for the host info store.
    """
    # Fresh containers per call: a mutable default ([] / {}) would be one
    # object shared by every call, and every test, that relies on it.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': mock.sentinel.dummy}
66
67
class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.

        Every module-level value replaced here is patched through mock and
        restored by a cleanup registered as soon as the patch starts. The
        originals therefore come back even if setUp fails partway through;
        tearDown would not run in that case.
        """
        # Patched with its own current value only so that tests which
        # reassign factory.SSH_ENGINE get the original restored afterwards.
        self._patch(factory, 'SSH_ENGINE', factory.SSH_ENGINE)
        self.host_types = []
        self._patch(factory, 'host_types', self.host_types)
        self.os_host_dict = {}
        self._patch(factory, 'OS_HOST_DICT', self.os_host_dict)
        self._patch(factory.cros_host, 'CrosHost',
                    _gen_mock_host('cros_host'))
        self._patch(factory.local_host, 'LocalHost', _gen_mock_conn('local'))
        self._patch(factory.ssh_host, 'SSHHost', _gen_mock_conn('ssh'))


    def _patch(self, target, attribute, new):
        """Replace target.attribute with new until the test finishes.

        @param target: object (here, a module) owning the attribute.
        @param attribute: name of the attribute to replace.
        @param new: replacement value.
        """
        patcher = mock.patch.object(target, attribute, new)
        patcher.start()
        self.addCleanup(patcher.stop)


    def test_use_specified(self):
        """Confirm that the specified host and connectivity classes are used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
                _gen_mock_conn('specified')
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')
        self.assertEqual(host_obj._conn_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        factory.SSH_ENGINE = 'raw_ssh'
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_choose_connectivity_unsupported(self):
        """Confirm exception when configured for unsupported ssh engine.
        """
        factory.SSH_ENGINE = 'unsupported'
        machine = _gen_machine_dict(hostname='somehost')
        with self.assertRaises(error.AutoservError):
            factory.create_host(machine)


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        # create=True lets mock add globals that factory may not define.
        # On exit they are removed again, or restored if they already
        # existed; a bare `del` would have discarded any pre-existing value.
        with mock.patch.multiple(factory, create=True,
                                 ssh_user='foo',
                                 ssh_pass='bar',
                                 ssh_port=1,
                                 ssh_verbosity_flag='baz',
                                 ssh_options='zip'):
            machine = _gen_machine_dict()
            host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'foo')
        self.assertEqual(host_obj._init_args['password'], 'bar')
        self.assertEqual(host_obj._init_args['port'], 1)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
        self.assertEqual(host_obj._init_args['ssh_options'], 'zip')


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')
227
228
class CreateTestbedUnittests(unittest.TestCase):
    """Tests for create_testbed function."""

    def setUp(self):
        """Mock out TestBed class to eliminate side effects.

        The patch is undone by a cleanup registered immediately after it
        starts, so no tearDown is needed.
        """
        patcher = mock.patch.object(factory.testbed, 'TestBed',
                                    _gen_mock_host('testbed'))
        patcher.start()
        self.addCleanup(patcher.stop)


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the testbed object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        testbed_obj = factory.create_testbed(machine, foo='bar')
        self.assertEqual(testbed_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', testbed_obj._init_args)
        self.assertIn('host_info_store', testbed_obj._init_args)
        self.assertEqual(testbed_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        # create=True lets mock add globals that factory may not define.
        # On exit they are removed again, or restored if they already
        # existed; a bare `del` would have discarded any pre-existing value.
        with mock.patch.multiple(factory, create=True,
                                 ssh_user='foo',
                                 ssh_pass='bar',
                                 ssh_port=1,
                                 ssh_verbosity_flag='baz',
                                 ssh_options='zip'):
            machine = _gen_machine_dict()
            testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'foo')
        self.assertEqual(testbed_obj._init_args['password'], 'bar')
        self.assertEqual(testbed_obj._init_args['port'], 1)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'],
                         'baz')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'zip')


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'somebody')
        self.assertEqual(testbed_obj._init_args['port'], 100)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'options')
294
295
if __name__ == '__main__':
    # Load and run the TestCases defined in this module when run directly.
    unittest.main(module='__main__')
298
299