1# Copyright 2019 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Ssh Utilities."""
15from __future__ import print_function
16import logging
17
18import subprocess
19import sys
20import threading
21
22from acloud import errors
23from acloud.internal import constants
24from acloud.internal.lib import utils
25
logger = logging.getLogger(__name__)

# Options placed right after the binary in every command line built by
# Ssh.GetBaseCmd (both ssh and scp): the private key file, quiet mode, and
# host key settings that skip known-hosts recording and host key checking.
_SSH_CMD = ("-i %(rsa_key_file)s "
            "-q -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no")
# Login user and target address; Ssh.GetBaseCmd appends this for ssh only.
_SSH_IDENTITY = "-l %(login_user)s %(ip_addr)s"
# Default retry count of ShellCmdWithRetry, Ssh.Run and Ssh.WaitForSsh.
_SSH_CMD_MAX_RETRY = 5
# Handed to utils.RetryExceptionType as sleep_multiplier by ShellCmdWithRetry.
_SSH_CMD_RETRY_SLEEP = 3
# Passed as "timeout" by Ssh.WaitForSsh to utils.RetryExceptionType;
# presumably forwarded to Ssh.CheckSshConnection as the per-attempt limit in
# seconds -- confirm against utils.RetryExceptionType.
_CONNECTION_TIMEOUT = 10
34
35
def _SshCallWait(cmd, timeout=None):
    """Runs a single SSH command and waits for it without reading its output.

    - SSH returns code 0 for "Successful execution".
    - Use wait() until the process is complete without receiving any output.
      The output is sent to the null device instead of a pipe: nothing reads
      it here, and an unread pipe can fill up and block the command forever.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional number of seconds to give the command before it is
                 killed. A falsy value (None or 0) means no time limit.

    Returns:
        The exit status of the command. 0 indicates that it ran successfully.
    """
    import os  # Local import; only needed for os.devnull.

    logger.info("Running command \"%s\"", cmd)
    timed_out = threading.Event()
    timer = None
    with open(os.devnull, "w") as devnull:
        process = subprocess.Popen(cmd, shell=True, stdin=None,
                                   stdout=devnull, stderr=subprocess.STDOUT)

        def _KillOnTimeout():
            """Records that the time limit was hit, then kills the process."""
            timed_out.set()
            process.kill()

        if timeout:
            timer = threading.Timer(timeout, _KillOnTimeout)
            timer.start()
        try:
            process.wait()
        finally:
            # Always stop the timer, even if wait() is interrupted, so that no
            # timer thread is left running behind the caller's back.
            if timer:
                timer.cancel()
    if timed_out.is_set():
        logger.warning("Command \"%s\" was killed after %s secs.", cmd, timeout)
    return process.returncode
61
62
def _SshCall(cmd, timeout=None):
    """Runs a single SSH command and waits until its output is drained.

    - SSH returns code 0 for "Successful execution".
    - Use communicate() until the process and the child thread are complete.
      The combined stdout/stderr is read and discarded.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional number of seconds to give the command before it is
                 killed. A falsy value (None or 0) means no time limit.

    Returns:
        The exit status of the command. 0 indicates that it ran successfully.
    """
    logger.info("Running command \"%s\"", cmd)
    process = subprocess.Popen(cmd, shell=True, stdin=None,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    timed_out = threading.Event()

    def _KillOnTimeout():
        """Records that the time limit was hit, then kills the process."""
        timed_out.set()
        process.kill()

    timer = None
    if timeout:
        timer = threading.Timer(timeout, _KillOnTimeout)
        timer.start()
    try:
        process.communicate()
    finally:
        # Always stop the timer, even if communicate() is interrupted, so that
        # no timer thread is left running behind the caller's back.
        if timer:
            timer.cancel()
    if timed_out.is_set():
        logger.warning("Command \"%s\" was killed after %s secs.", cmd, timeout)
    return process.returncode
88
89
def _SshLogOutput(cmd, timeout=None, show_output=False):
    """Runs a single SSH command, reports its output and checks its return code.

    The combined stdout/stderr is collected until the command exits. It is
    printed to stderr when show_output is set or the command failed, and
    logged at the debug level otherwise.
    SSH returns error code 255 for "failed to connect", so this is interpreted
    as a failure in SSH rather than a failure on the target device and this is
    converted to a different exception type.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional number of seconds to give the command before it is
                 killed. A falsy value (None or 0) means no time limit.
        show_output: Boolean, True to show command output in screen.

    Raises:
        errors.DeviceConnectionError: Failed to connect to the GCE instance
                                      (return code 255).
        subprocess.CalledProcessError: The process exited with any other
                                       non-zero return code.
    """
    # Use "exec" to let cmd to inherit the shell process, instead of having the
    # shell launch a child process which does not get killed.
    cmd = "exec " + cmd
    logger.info("Running command \"%s\"", cmd)
    process = subprocess.Popen(cmd, shell=True, stdin=None,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               universal_newlines=True)
    timed_out = threading.Event()

    def _KillOnTimeout():
        """Records that the time limit was hit, then kills the process."""
        timed_out.set()
        process.kill()

    timer = None
    if timeout:
        timer = threading.Timer(timeout, _KillOnTimeout)
        timer.start()
    try:
        stdout, _ = process.communicate()
    finally:
        # Always stop the timer, even if communicate() is interrupted, so that
        # no timer thread is left running behind the caller's back.
        if timer:
            timer.cancel()
    if timed_out.is_set():
        logger.warning("Command \"%s\" was killed after %s secs.", cmd, timeout)
    if stdout:
        if show_output or process.returncode != 0:
            print(stdout.strip(), file=sys.stderr)
        else:
            # fetch_cvd and launch_cvd can be noisy, so left at debug
            logger.debug(stdout.strip())
    if process.returncode == 255:
        raise errors.DeviceConnectionError(
            "Failed to send command to instance (%s)" % cmd)
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, cmd)
132
133
def ShellCmdWithRetry(cmd, timeout=None, show_output=False,
                      retry=_SSH_CMD_MAX_RETRY):
    """Runs a shell command on the remote device, retrying on failure.

    The command is executed through _SshLogOutput. Both a failed SSH
    connection (errors.DeviceConnectionError) and a non-zero exit of the
    command (subprocess.CalledProcessError) are handed to
    utils.RetryExceptionType as retriable. The pause between attempts is
    controlled by _SSH_CMD_RETRY_SLEEP and utils.DEFAULT_RETRY_BACKOFF_FACTOR,
    so that retries on an unstable network are not all made back to back.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional integer, number of seconds to give each attempt.
        show_output: Boolean, True to show command output in screen.
        retry: Integer, the retry times.

    Raises:
        errors.DeviceConnectionError: The SSH connection kept failing.
        subprocess.CalledProcessError: The command kept exiting with a
                                       non-zero return code.
        (Presumably re-raised by utils.RetryExceptionType once the retries
        are used up -- confirm against that helper.)
    """
    retriable_errors = (errors.DeviceConnectionError,
                        subprocess.CalledProcessError)
    # Keyword arguments that are forwarded to _SshLogOutput on every attempt.
    functor_kwargs = {
        "cmd": cmd,
        "timeout": timeout,
        "show_output": show_output,
    }
    utils.RetryExceptionType(
        exception_types=retriable_errors,
        max_retries=retry,
        functor=_SshLogOutput,
        sleep_multiplier=_SSH_CMD_RETRY_SLEEP,
        retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
        **functor_kwargs)
162
163
class IP(object):
    """Holds an external and an internal IP address."""

    def __init__(self, external=None, internal=None, ip=None):
        """Init for IP.

        Args:
            external: String, external ip.
            internal: String, internal ip.
            ip: String, fallback used for whichever of external/internal is
                not given (or is empty).
        """
        self.external = external if external else ip
        self.internal = internal if internal else ip
176
177
class Ssh(object):
    """A class that controls the remote instance via the IP address.

    Attributes:
        _ip: String, the address used to reach the instance: the internal ip
             of the given IP object when report_internal_ip is set, otherwise
             its external ip.
        _user: String of user login into the instance.
        _ssh_private_key_path: Path to the private key file.
        _extra_args_ssh_tunnel: String, extra args for ssh or scp.
    """
    def __init__(self, ip, user, ssh_private_key_path,
                 extra_args_ssh_tunnel=None, report_internal_ip=False):
        """Init for Ssh.

        Args:
            ip: IP object holding the external and internal ip.
            user: String of user login into the instance.
            ssh_private_key_path: Path to the private key file.
            extra_args_ssh_tunnel: String, extra args for ssh or scp.
            report_internal_ip: Boolean, True to connect through the internal
                                ip instead of the external one.
        """
        self._ip = ip.internal if report_internal_ip else ip.external
        self._user = user
        self._ssh_private_key_path = ssh_private_key_path
        self._extra_args_ssh_tunnel = extra_args_ssh_tunnel

    def Run(self, target_command, timeout=None, show_output=False,
            retry=_SSH_CMD_MAX_RETRY):
        """Run a shell command over SSH on a remote instance.

        The executed command line is the ssh base command followed by the
        target command, e.g.
            ssh -i ~/private_key_path $extra_args -l user 1.1.1.1 remote command

        Args:
            target_command: String, text of command to run on the remote instance.
            timeout: Integer, the maximum time to wait for the command to respond.
            show_output: Boolean, True to show command output in screen.
            retry: Integer, the retry times.
        """
        ShellCmdWithRetry(self.GetBaseCmd(constants.SSH_BIN) + " " + target_command,
                          timeout,
                          show_output,
                          retry)

    def GetBaseCmd(self, execute_bin):
        """Get a base command over SSH on a remote instance.

        Example:
            execute bin is ssh:
                ssh -i ~/private_key_path $extra_args -l user 1.1.1.1
            execute bin is scp:
                scp -i ~/private_key_path $extra_args

        Args:
            execute_bin: String, execute type, e.g. ssh or scp.

        Returns:
            Strings of base connection command.

        Raises:
            errors.UnknownType: Don't support the execute bin.
        """
        base_cmd = [utils.FindExecutable(execute_bin)]
        base_cmd.append(_SSH_CMD % {"rsa_key_file": self._ssh_private_key_path})
        if self._extra_args_ssh_tunnel:
            base_cmd.append(self._extra_args_ssh_tunnel)

        if execute_bin == constants.SSH_BIN:
            # Only ssh takes the login user and the address as options; scp
            # gets them as part of the remote path (see ScpPushFile/ScpPullFile).
            base_cmd.append(_SSH_IDENTITY %
                            {"login_user": self._user, "ip_addr": self._ip})
            return " ".join(base_cmd)
        if execute_bin == constants.SCP_BIN:
            return " ".join(base_cmd)

        raise errors.UnknownType("Don't support the execute bin %s." % execute_bin)

    def GetCmdOutput(self, cmd):
        """Runs a single SSH command and get its output.

        The command is run once, without retry or timeout, and its return code
        is not checked.

        Args:
            cmd: String, text of command to run on the remote instance.

        Returns:
            String of the command output (stdout and stderr combined).
        """
        ssh_cmd = "exec " + self.GetBaseCmd(constants.SSH_BIN) + " " + cmd
        logger.info("Running command \"%s\"", ssh_cmd)
        process = subprocess.Popen(ssh_cmd, shell=True, stdin=None,
                                   stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                   universal_newlines=True)
        stdout, _ = process.communicate()
        return stdout

    def CheckSshConnection(self, timeout):
        """Run remote 'uptime' ssh command to check ssh connection.

        Args:
            timeout: Integer, the maximum time to wait for the command to respond.

        Raises:
            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
        """
        remote_cmd = [self.GetBaseCmd(constants.SSH_BIN)]
        remote_cmd.append("uptime")

        if _SshCallWait(" ".join(remote_cmd), timeout) == 0:
            return
        raise errors.DeviceConnectionError(
            "Ssh isn't ready in the remote instance.")

    @utils.TimeExecute(function_description="Waiting for SSH server")
    def WaitForSsh(self, timeout=None, max_retry=_SSH_CMD_MAX_RETRY):
        """Wait until the remote instance is ready to accept commands over SSH.

        Each connection attempt is limited to _CONNECTION_TIMEOUT seconds; the
        timeout argument is only used to derive the interval between retries.

        Args:
            timeout: Integer, the maximum time in seconds to wait for the
                     command to respond. Defaults to
                     constants.DEFAULT_SSH_TIMEOUT when not given.
            max_retry: Integer, the maximum number of retry.

        Raises:
            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
        """
        ssh_timeout = timeout or constants.DEFAULT_SSH_TIMEOUT
        # 1 + 2 + ... + max_retry. Presumably the retry helper scales the
        # interval with the attempt number, so this spreads ssh_timeout over
        # all retries -- confirm against utils.RetryExceptionType.
        # Clamp to 1 so that max_retry <= 0 does not divide by zero.
        interval_divisor = max(sum(range(max_retry + 1)), 1)
        # float() keeps a fractional interval under Python 2 integer division.
        sleep_multiplier = float(ssh_timeout) / interval_divisor
        logger.debug("Retry with interval time: %s secs", str(sleep_multiplier))
        utils.RetryExceptionType(
            exception_types=errors.DeviceConnectionError,
            max_retries=max_retry,
            functor=self.CheckSshConnection,
            sleep_multiplier=sleep_multiplier,
            retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
            timeout=_CONNECTION_TIMEOUT)

    def ScpPushFile(self, src_file, dst_file):
        """Scp push file to remote.

        Args:
            src_file: The local source file path to be pushed.
            dst_file: The destination file path on the remote instance.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.append(src_file)
        scp_command.append("%s@%s:%s" % (self._user, self._ip, dst_file))
        ShellCmdWithRetry(" ".join(scp_command))

    def ScpPullFile(self, src_file, dst_file):
        """Scp pull file from remote.

        Args:
            src_file: The source file path on the remote instance to be pulled.
            dst_file: The local destination file path the file is pulled to.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.append("%s@%s:%s" % (self._user, self._ip, src_file))
        scp_command.append(dst_file)
        ShellCmdWithRetry(" ".join(scp_command))
329