1# Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import common
6import inspect, new, socket, sys
7
8from autotest_lib.client.bin import utils
9from autotest_lib.cli import host, rpc
10from autotest_lib.server import hosts
11from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
12from autotest_lib.client.common_lib import error, host_protections
13
14
# In order for hosts to work correctly, some of its variables must be set up.
# These attributes on hosts.factory are assigned once, at import time of this
# module.  NOTE(review): presumably they are read when hosts.create_host() /
# hosts.create_testbed() (used by site_host_create below) build SSH-backed
# host objects -- TODO confirm against hosts/factory.py.
hosts.factory.ssh_user = 'root'
# Empty password; presumably key-based auth is expected -- verify.
hosts.factory.ssh_pass = ''
hosts.factory.ssh_port = 22
hosts.factory.ssh_verbosity_flag = ''
hosts.factory.ssh_options = ''
21
22
# pylint: disable=missing-docstring
class site_host(host.host):
    """Site-specific base class mixed into every site_* class in this module."""
26
27
class site_host_create(site_host, host.host_create):
    """
    site_host_create subclasses host_create in host.py.

    Before adding hosts it probes every reachable host for its platform and
    labels, makes sure all referenced platform labels, labels and ACLs exist
    on the AFE, and always adds the hosts in a locked state.
    """

    @classmethod
    def construct_without_parse(
            cls, web_server, hosts, platform=None,
            locked=False, lock_reason='', labels=None, acls=None,
            protection=host_protections.Protection.NO_PROTECTION):
        """Construct an site_host_create object and fill in data from args.

        Do not need to call parse after the construction.

        Return an object of site_host_create ready to execute.

        @param web_server: A string specifies the autotest webserver url.
            It is needed to setup comm to make rpc.
        @param hosts: A list of hostnames as strings.  (Note: this parameter
            shadows the `hosts` module inside this method only.)
        @param platform: A string or None.
        @param locked: A boolean.
        @param lock_reason: A string; None is treated like ''.  Only stored
            when `locked` is True and the stripped reason is non-empty.
        @param labels: A list of labels as strings, or None for no labels.
        @param acls: A list of acls as strings, or None for no acls.
        @param protection: An enum defined in host_protections.
        """
        obj = cls()
        obj.web_server = web_server
        try:
            # Setup stuff needed for afe comm.
            obj.afe = rpc.afe_comm(web_server)
        except rpc.AuthError as s:
            obj.failure(str(s), fatal=True)
        obj.hosts = hosts
        obj.platform = platform
        obj.locked = locked
        lock_reason = (lock_reason or '').strip()
        if locked and lock_reason:
            obj.data['lock_reason'] = lock_reason
        # Use a fresh list per call rather than a shared mutable default.
        obj.labels = [] if labels is None else labels
        obj.acls = [] if acls is None else acls
        if protection:
            obj.data['protection'] = protection
        # TODO(kevcheng): Update the admin page to take in serials?
        obj.serials = None
        return obj


    def _execute_add_one_host(self, host):
        """Add one host to the AFE, then attach its labels and serials.

        @param host: Hostname string.  Must be a key of self.host_info_map,
            which execute() populates before hosts are added.
        """
        # Always add the hosts as locked to avoid the host
        # being picked up by the scheduler before it's ACL'ed.
        self.data['locked'] = True
        if not self.locked:
            self.data['lock_reason'] = 'Forced lock on device creation'
        self.execute_rpc('add_host', hostname=host,
                         status="Ready", **self.data)
        # Combine the requested labels with any retrieved from the host.
        host_info = self.host_info_map[host]
        labels = set(self.labels)
        if host_info.labels:
            labels.update(host_info.labels)
        # Now add the platform label.  An explicitly provided platform wins;
        # otherwise use the platform retrieved from the host (if any).
        platform = self.platform or host_info.platform
        if platform:
            labels.add(platform)

        if labels:
            self.execute_rpc('host_add_labels', id=host, labels=list(labels))

        if self.serials:
            afe = frontend_wrappers.RetryingAFE(timeout_min=5, delay_sec=10)
            afe.set_host_attribute('serials', ','.join(self.serials),
                                   hostname=host)


    @staticmethod
    def _ordered_unique(items):
        """Return `items` as a list without duplicates, in first-seen order."""
        seen = set()
        unique = []
        for item in items:
            if item not in seen:
                seen.add(item)
                unique.append(item)
        return unique


    def _probe_host_info(self, hostname):
        """Return a host_information for `hostname`, probing it if reachable.

        Hosts that do not answer a ping, cannot be resolved, or fail over SSH
        get a host_information with no platform and no labels.

        @param hostname: Hostname string.
        """
        try:
            if utils.ping(hostname, tries=1, deadline=1) != 0:
                # Can't ping the host, use default information.
                return host_information(hostname, None, [])
            if self.serials and len(self.serials) > 1:
                host_dut = hosts.create_testbed(
                        hostname, adb_serials=self.serials)
            else:
                adb_serial = self.serials[0] if self.serials else None
                host_dut = hosts.create_host(hostname, adb_serial=adb_serial)
            return host_information(hostname,
                                    host_dut.get_platform(),
                                    host_dut.get_labels())
        except (socket.gaierror, error.AutoservRunError,
                error.AutoservSSHTimeout):
            # We may be adding a host that does not exist yet or we can't
            # reach due to hostname/address issues or if the host is down.
            return host_information(hostname, None, [])


    def execute(self):
        """Probe the hosts, create missing labels/ACLs, then add the hosts.

        @returns: Whatever self._execute_add_hosts() returns.
        """
        # Check to see if the platform or any other labels can be grabbed from
        # the hosts.
        self.host_info_map = {}
        for hostname in self.hosts:
            self.host_info_map[hostname] = self._probe_host_info(hostname)

        # We need to check if these labels & ACLs exist,
        # and create them if not.
        if self.platform:
            platforms = [self.platform]
        else:
            # No platform was provided so check and create the platform label
            # for each host.
            platforms = self._ordered_unique(
                    info.platform for info in self.host_info_map.values()
                    if info.platform)
        if platforms:
            self.check_and_create_items('get_labels', 'add_label',
                                        platforms,
                                        platform=True)

        # Deduplicate so each label is checked/created only once.
        labels_to_check_and_create = list(self.labels)
        for info in self.host_info_map.values():
            labels_to_check_and_create.extend(info.labels or [])
        labels_to_check_and_create = self._ordered_unique(
                labels_to_check_and_create)
        if labels_to_check_and_create:
            self.check_and_create_items('get_labels', 'add_label',
                                        labels_to_check_and_create,
                                        platform=False)

        if self.acls:
            self.check_and_create_items('get_acl_groups',
                                        'add_acl_group',
                                        self.acls)

        return self._execute_add_hosts()
164
165
class host_information(object):
    """Store host information so we don't have to keep looking it up."""


    def __init__(self, hostname, platform, labels):
        """
        @param hostname: The host's name.
        @param platform: The host's platform label, or None if unknown.
        @param labels: A list of labels for the host (may be empty).
        """
        self.hostname = hostname
        self.platform = platform
        self.labels = labels


    def __repr__(self):
        return '%s(%r, %r, %r)' % (type(self).__name__, self.hostname,
                                   self.platform, self.labels)
174
175
# Any classes we don't override in host should be copied automatically.
def _create_default_site_classes():
    """Create a default site_<name> class for every public class in `host`.

    Each generated class derives from (site_host, <host class>) and copies
    the host class's docstring.  Names that already exist in this module
    (e.g. the hand-written site_host_create) are left untouched.
    """
    this_module = sys.modules[__name__]
    # Calling site_host's own metaclass is what the deprecated
    # new.classobj() dispatched to for these bases.
    metaclass = type(site_host)
    for attr_name in dir(host):
        if attr_name.startswith('_'):
            continue
        base_cls = getattr(host, attr_name)
        if not inspect.isclass(base_cls):
            continue
        site_cls_name = 'site_' + base_cls.__name__
        if hasattr(this_module, site_cls_name):
            continue
        site_cls = metaclass(site_cls_name, (site_host, base_cls),
                             {'__doc__': base_cls.__doc__})
        setattr(this_module, site_cls_name, site_cls)


_create_default_site_classes()
188