1#!/usr/bin/env python
2#
3# Copyright 2018 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Common operations between managing GCE and Cuttlefish devices.
17
18This module provides the common operations between managing GCE (device_driver)
19and Cuttlefish (create_cuttlefish_action) devices. Should not be called
20directly.
21"""
22
23import logging
24import os
25
26from acloud import errors
27from acloud.public import avd
28from acloud.public import report
29from acloud.internal import constants
30from acloud.internal.lib import utils
31from acloud.internal.lib.adb_tools import AdbTools
32
33
logger = logging.getLogger(__name__)

# Error-type strings recorded in the report via Report.SetErrorType().
_ACLOUD_GENERIC_ERROR = "ACLOUD_GENERIC_ERROR"
_ACLOUD_SSH_CONNECT_ERROR = "ACLOUD_SSH_CONNECT_ERROR"
_ACLOUD_DOWNLOAD_ARTIFACT_ERROR = "ACLOUD_DOWNLOAD_ARTIFACT_ERROR"
_ACLOUD_BOOT_UP_ERROR = "ACLOUD_BOOT_UP_ERROR"

# GCE quota failures: the dedicated error type, plus the substring used to
# recognize a quota problem inside an otherwise generic error message.
_GCE_QUOTA_ERROR = "GCE_QUOTA_ERROR"
_GCE_QUOTA_ERROR_MSG = "Quota exceeded for quota"

# Creation stage a device reached -> error type reported when that device
# fails to boot.
_DICT_ERROR_TYPE = {
    constants.STAGE_INIT: "ACLOUD_INIT_ERROR",
    constants.STAGE_GCE: "ACLOUD_CREATE_GCE_ERROR",
    constants.STAGE_SSH_CONNECT: _ACLOUD_SSH_CONNECT_ERROR,
    constants.STAGE_ARTIFACT: _ACLOUD_DOWNLOAD_ARTIFACT_ERROR,
    constants.STAGE_BOOT_UP: _ACLOUD_BOOT_UP_ERROR,
}
49
50
def CreateSshKeyPairIfNecessary(cfg):
    """Create ssh key pair if necessary.

    A key pair is only created (when not already present) if the config
    specifies both ssh_public_key_path and ssh_private_key_path. If either
    path is missing, a warning is logged explaining which key will be used
    and nothing is created.

    Args:
        cfg: An AcloudConfig instance. Only its ssh_public_key_path and
            ssh_private_key_path attributes are read.

    Raises:
        Any exception raised by utils.CreateSshKeyPairIfNotExist propagates
        unchanged (presumably errors.DriverError on failure - verify).
    """
    if not cfg.ssh_public_key_path:
        logger.warning(
            "ssh_public_key_path is not specified in acloud config. "
            "Project-wide public key will "
            "be used when creating AVD instances. "
            "Please ensure you have the correct private half of "
            "a project-wide public key if you want to ssh into the "
            "instances after creation.")
        return

    # From here on the public key path is known to be set.
    if not cfg.ssh_private_key_path:
        logger.warning(
            "Only ssh_public_key_path is specified in acloud config, "
            "but ssh_private_key_path is missing. "
            "Please ensure you have the correct private half "
            "if you want to ssh into the instances after creation.")
        return

    utils.CreateSshKeyPairIfNotExist(cfg.ssh_private_key_path,
                                     cfg.ssh_public_key_path)
81
82
class DevicePool:
    """A class that manages a pool of virtual devices.

    Attributes:
        devices: A list of devices in the pool.
    """

    def __init__(self, device_factory, devices=None):
        """Constructs a new DevicePool.

        Args:
            device_factory: A device factory capable of producing a goldfish or
                cuttlefish device. The device factory must expose an attribute with
                the credentials that can be used to retrieve information from the
                constructed device.
            devices: List of devices managed by this pool.
        """
        self._devices = devices or []
        self._device_factory = device_factory
        # The compute client is used for IP lookup, boot waiting, serial log
        # retrieval and report data.
        self._compute_client = device_factory.GetComputeClient()

    def CreateDevices(self, num):
        """Creates |num| devices for given build_target and build_id.

        Each created instance is wrapped in an avd.AndroidVirtualDevice and
        appended to the pool.

        Args:
            num: Number of devices to create.
        """
        # Create host instances for cuttlefish/goldfish device.
        # Currently one instance supports only 1 device.
        for _ in range(num):
            instance = self._device_factory.CreateInstance()
            ip = self._compute_client.GetInstanceIP(instance)
            # Not every compute client records timing/stage information;
            # fall back to an empty dict / stage 0 when it is absent.
            time_info = getattr(self._compute_client, "execution_time", {})
            stage = getattr(self._compute_client, "stage", 0)
            self._devices.append(
                avd.AndroidVirtualDevice(ip=ip, instance_name=instance,
                                         time_info=time_info, stage=stage))

    @utils.TimeExecute(function_description="Waiting for AVD(s) to boot up",
                       result_evaluator=utils.BootEvaluator)
    def WaitForBoot(self, boot_timeout_secs):
        """Waits for all devices to boot up.

        Devices are waited on one after another; a boot failure on one
        device is recorded and does not stop waiting on the others.

        Args:
            boot_timeout_secs: Integer, the maximum time in seconds used to
                               wait for the AVD to boot.

        Returns:
            A dictionary that contains all the failures.
            The key is the name of the instance that fails to boot,
            and the value is an errors.DeviceBootError object.
        """
        failures = {}
        for device in self._devices:
            try:
                self._compute_client.WaitForBoot(device.instance_name, boot_timeout_secs)
            except errors.DeviceBootError as e:
                failures[device.instance_name] = e
        return failures

    def UpdateReport(self, reporter):
        """Update report from compute client.

        Args:
            reporter: Report object; receives the compute client's dict_report.
        """
        reporter.UpdateData(self._compute_client.dict_report)

    def CollectSerialPortLogs(self, output_file,
                              port=constants.DEFAULT_SERIAL_PORT):
        """Tar the instance serial logs into specified output_file.

        One "<instance>_serial_<port>.log" file per device is written to a
        temporary directory and then packed into output_file.

        Args:
            output_file: String, the output tar file path
            port: The serial port number to be collected
        """
        # For emulator, the serial log is the virtual host serial log.
        # For GCE AVD device, the serial log is the AVD device serial log.
        with utils.TempDir() as tempdir:
            src_dict = {}
            for device in self._devices:
                logger.info("Store instance %s serial port %s output to %s",
                            device.instance_name, port, output_file)
                serial_log = self._compute_client.GetSerialPortOutput(
                    instance=device.instance_name, port=port)
                file_name = "%s_serial_%s.log" % (device.instance_name, port)
                file_path = os.path.join(tempdir, file_name)
                src_dict[file_path] = file_name
                # Write the text through a UTF-8 text handle. Writing
                # serial_log.encode(...) (bytes) to a file opened in text
                # mode raises TypeError on Python 3.
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(serial_log)
            utils.MakeTarFile(src_dict, output_file)

    def SetDeviceBuildInfo(self):
        """Add devices build info.

        Sets each device's build_info to the factory's GetBuildInfoDict()
        result.
        """
        for device in self._devices:
            device.build_info = self._device_factory.GetBuildInfoDict()

    @property
    def devices(self):
        """Returns a list of devices in the pool.

        Returns:
            A list of devices in the pool.
        """
        return self._devices
190
def _GetErrorType(error):
    """Map an exception caught during device creation to a report error type.

    Args:
        error: The caught exception object.

    Returns:
        String of error type, e.g. "ACLOUD_GENERIC_ERROR".
    """
    # Checked in order; the first exception class that matches wins.
    typed_errors = (
        (errors.CheckGCEZonesQuotaError, _GCE_QUOTA_ERROR),
        (errors.DownloadArtifactError, _ACLOUD_DOWNLOAD_ARTIFACT_ERROR),
        (errors.DeviceConnectionError, _ACLOUD_SSH_CONNECT_ERROR),
    )
    for error_class, error_type in typed_errors:
        if isinstance(error, error_class):
            return error_type
    # Quota problems can also surface as generic errors; detect by message.
    quota_hit = _GCE_QUOTA_ERROR_MSG in str(error)
    return _GCE_QUOTA_ERROR if quota_hit else _ACLOUD_GENERIC_ERROR
209
# pylint: disable=too-many-locals,unused-argument,too-many-branches
def CreateDevices(command, cfg, device_factory, num, avd_type,
                  report_internal_ip=False, autoconnect=False,
                  serial_log_file=None, client_adb_port=None,
                  boot_timeout_secs=None, unlock_screen=False,
                  wait_for_boot=True, connect_webrtc=False):
    """Create a set of devices using the given factory.

    Main jobs in create devices.
        1. Create GCE instance: Launch instance in GCP(Google Cloud Platform).
        2. Starting up AVD: Wait device boot up.

    Args:
        command: The name of the command, used for reporting.
        cfg: An AcloudConfig instance.
        device_factory: A factory capable of producing a single device.
        num: The number of devices to create.
        avd_type: String, the AVD type(cuttlefish, goldfish...).
        report_internal_ip: Boolean to report the internal ip instead of
                            external ip.
        autoconnect: Boolean, whether to auto connect to device.
        serial_log_file: String, the file path to tar the serial logs.
        client_adb_port: Integer, Specify port for adb forwarding.
        boot_timeout_secs: Integer, boot timeout secs.
        unlock_screen: Boolean, whether to unlock screen after invoke vnc client.
        wait_for_boot: Boolean, True to check serial log include boot up
                       message.
        connect_webrtc: Boolean, whether to auto connect webrtc to device.

    Returns:
        A Report instance. errors.DriverError and
        errors.CheckGCEZonesQuotaError raised while creating devices are
        caught and recorded in the report (status FAIL, error type and
        message) instead of being raised; any other exception propagates.
    """
    reporter = report.Report(command=command)
    try:
        CreateSshKeyPairIfNecessary(cfg)
        device_pool = DevicePool(device_factory)
        device_pool.CreateDevices(num)
        device_pool.SetDeviceBuildInfo()
        if wait_for_boot:
            failures = device_pool.WaitForBoot(boot_timeout_secs)
        else:
            failures = device_factory.GetFailures()
        # Normalize so the per-device membership test below never sees None.
        failures = failures or {}

        if failures:
            reporter.SetStatus(report.Status.BOOT_FAIL)
        else:
            reporter.SetStatus(report.Status.SUCCESS)

        # Collect logs
        if serial_log_file:
            device_pool.CollectSerialPortLogs(
                serial_log_file, port=constants.DEFAULT_SERIAL_PORT)

        device_pool.UpdateReport(reporter)
        # Write result to report.
        for device in device_pool.devices:
            ip = (device.ip.internal if report_internal_ip
                  else device.ip.external)
            device_dict = {
                "ip": ip,
                "instance_name": device.instance_name
            }
            if device.build_info:
                device_dict.update(device.build_info)
            if device.time_info:
                device_dict.update(device.time_info)
            if autoconnect:
                forwarded_ports = utils.AutoConnect(
                    ip_addr=ip,
                    rsa_key_file=cfg.ssh_private_key_path,
                    target_vnc_port=utils.AVD_PORT_DICT[avd_type].vnc_port,
                    target_adb_port=utils.AVD_PORT_DICT[avd_type].adb_port,
                    ssh_user=constants.GCE_USER,
                    client_adb_port=client_adb_port,
                    extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
                device_dict[constants.VNC_PORT] = forwarded_ports.vnc_port
                device_dict[constants.ADB_PORT] = forwarded_ports.adb_port
                device_dict[constants.DEVICE_SERIAL] = (
                    constants.REMOTE_INSTANCE_ADB_SERIAL %
                    forwarded_ports.adb_port)
                if unlock_screen:
                    AdbTools(forwarded_ports.adb_port).AutoUnlockScreen()
            if connect_webrtc:
                utils.EstablishWebRTCSshTunnel(
                    ip_addr=ip,
                    rsa_key_file=cfg.ssh_private_key_path,
                    ssh_user=constants.GCE_USER,
                    extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
            if device.instance_name in failures:
                reporter.SetErrorType(_ACLOUD_BOOT_UP_ERROR)
                # Refine the error type by the stage the device reached.
                # An unmapped stage keeps the boot-up error instead of
                # raising an uncaught KeyError that would abort the report.
                stage_error = _DICT_ERROR_TYPE.get(device.stage)
                if device.stage and stage_error:
                    reporter.SetErrorType(stage_error)
                reporter.AddData(key="devices_failing_boot", value=device_dict)
                reporter.AddError(str(failures[device.instance_name]))
            else:
                reporter.AddData(key="devices", value=device_dict)
    except (errors.DriverError, errors.CheckGCEZonesQuotaError) as e:
        reporter.SetErrorType(_GetErrorType(e))
        reporter.AddError(str(e))
        reporter.SetStatus(report.Status.FAIL)
    return reporter
314