1#!/usr/bin/python
2# Copyright 2016 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6import unittest
7
8import common
9from autotest_lib.client.common_lib import error
10from autotest_lib.server.hosts import base_label_unittest, factory
11from autotest_lib.server.hosts import host_info
12
13
class MockHost(object):
    """Side-effect-free stand-in for a host class.

    The constructor arguments are recorded in `_init_args` so that tests can
    inspect what the factory passed in.
    """
    def __init__(self, hostname, **args):
        # Keep every keyword argument together with the hostname.
        self._init_args = dict(args, hostname=hostname)


    def job_start(self):
        """No-op stand-in for the only method called by factory."""
        return None
24
25
class MockConnectivity(object):
    """Side-effect-free stand-in for a connectivity class.

    Every method accepts its arguments and does nothing.
    """
    def __init__(self, hostname, **args):
        """Accept and discard the constructor arguments."""

    def run(self, *args, **kwargs):
        """Accept any arguments and do nothing."""

    def close(self):
        """Do nothing."""
36
37
def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: identifier stored on the class as `_host_cls_name` and used
                 as the suffix of the generated class name.
    @param check_host: value returned by the class's static `check_host`
                       method, whatever arguments it is called with.

    @return: a new subclass of MockHost.
    """
    def _check_host(host, timeout=None):
        return check_host

    members = {
        '_host_cls_name': name,
        'check_host': staticmethod(_check_host),
    }
    return type('mock_host_%s' % name, (MockHost,), members)
45
46
def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: identifier stored on the class as `_conn_cls_name` and used
                 as the suffix of the generated class name.

    @return: a new subclass of MockConnectivity.
    """
    cls_name = 'mock_conn_%s' % name
    members = {'_conn_cls_name': name}
    return type(cls_name, (MockConnectivity,), members)
52
53
def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; defaults to a new empty list
    @param attributes: dict of host attributes; defaults to a new empty dict

    @return: machine dict with the hostname, a mocked AFE Host object and an
             in-memory host info store holding the labels and attributes.
    """
    # Build fresh containers on every call: mutable default arguments are
    # created once and would otherwise be shared by every machine dict.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    store = host_info.InMemoryHostInfoStore()
    store.commit(host_info.HostInfo(labels, attributes))
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': store}
69
70
class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.

        The factory's host type list, OS host dict and host/connectivity
        classes are saved and then replaced with empty containers and mock
        classes; tearDown puts the saved values back.
        """
        self._orig_types = factory.host_types
        self._orig_dict = factory.OS_HOST_DICT
        self._orig_cros_host = factory.cros_host.CrosHost
        self._orig_local_host = factory.local_host.LocalHost
        self._orig_ssh_host = factory.ssh_host.SSHHost

        # Tests populate these containers through self.host_types and
        # self.os_host_dict; the factory sees the very same objects.
        self.host_types = factory.host_types = []
        self.os_host_dict = factory.OS_HOST_DICT = {}
        factory.cros_host.CrosHost = _gen_mock_host('cros_host')
        factory.local_host.LocalHost = _gen_mock_conn('local')
        factory.ssh_host.SSHHost = _gen_mock_conn('ssh')


    def tearDown(self):
        """Clean up mocks by restoring the values saved in setUp."""
        factory.host_types = self._orig_types
        factory.OS_HOST_DICT = self._orig_dict
        factory.cros_host.CrosHost = self._orig_cros_host
        factory.local_host.LocalHost = self._orig_local_host
        factory.ssh_host.SSHHost = self._orig_ssh_host


    def test_use_specified(self):
        """Confirm that the specified host class is used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        # Only the second candidate reports a positive check_host result.
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        # assertIn reports the missing key and the dict contents on failure.
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.

        The globals are set as attributes of the factory module for the
        duration of the test. Afterwards any value that existed beforehand is
        put back and names that did not exist are removed again.
        """
        ssh_globals = {
            'ssh_user': 'foo',
            'ssh_pass': 'bar',
            'ssh_port': 1,
            'ssh_verbosity_flag': 'baz',
            'ssh_options': 'zip',
        }
        # Sentinel marking globals that were not defined before the test.
        undefined = object()
        saved = {}
        for name in ssh_globals:
            saved[name] = getattr(factory, name, undefined)
        try:
            # Everything that can raise happens inside the try block so the
            # factory module is always returned to its previous state.
            for name, value in ssh_globals.items():
                setattr(factory, name, value)
            machine = _gen_machine_dict()
            host_obj = factory.create_host(machine)
            self.assertEqual(host_obj._init_args['user'], 'foo')
            self.assertEqual(host_obj._init_args['password'], 'bar')
            self.assertEqual(host_obj._init_args['port'], 1)
            self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
            self.assertEqual(host_obj._init_args['ssh_options'], 'zip')
        finally:
            for name, previous in saved.items():
                if previous is not undefined:
                    setattr(factory, name, previous)
                elif hasattr(factory, name):
                    delattr(factory, name)


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')
216
217
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
220
221