1# Copyright 2019 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Ssh Utilities."""
15from __future__ import print_function
16import logging
17
18import subprocess
19import sys
20import threading
21
22from acloud import errors
23from acloud.internal import constants
24from acloud.internal.lib import utils
25
logger = logging.getLogger(__name__)

# Options shared by every ssh/scp invocation: the identity file, quiet mode,
# and host-key checking disabled (known hosts are not persisted).
_SSH_CMD = ("-i %(rsa_key_file)s -q "
            "-o UserKnownHostsFile=/dev/null "
            "-o StrictHostKeyChecking=no")
# Login user and target address; appended for ssh only (see Ssh.GetBaseCmd).
_SSH_IDENTITY = "-l %(login_user)s %(ip_addr)s"
# Retry count and sleep multiplier handed to utils.RetryExceptionType.
_SSH_CMD_MAX_RETRY = 5
_SSH_CMD_RETRY_SLEEP = 3
# Per-attempt timeout in seconds used by WaitForSsh when none is derived from
# the caller's timeout.
_WAIT_FOR_SSH_MAX_TIMEOUT = 60
34
35
def _SshCallWait(cmd, timeout=None):
    """Runs a single SSH command and waits for it to exit, discarding output.

    - SSH returns code 0 for "Successful execution".
    - Uses wait() rather than communicate(), so this returns as soon as the
      launched process exits, without reading any of its output.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional integer, number of seconds to give the command
                 before the launched process is killed.

    Returns:
        The exit status of the process; 0 indicates that it ran successfully.
    """
    # Local import: the module header does not import os.
    import os

    logger.info("Running command \"%s\"", cmd)
    # Nothing reads the output here, so send it to the null device instead of
    # a pipe: wait() on an unread pipe can block forever once the child has
    # filled the OS pipe buffer.
    with open(os.devnull, "wb") as devnull:
        process = subprocess.Popen(cmd, shell=True, stdin=None,
                                   stdout=devnull, stderr=subprocess.STDOUT)
        timer = None
        if timeout:
            # TODO: if process is killed, out error message to log.
            timer = threading.Timer(timeout, process.kill)
            timer.start()
        try:
            process.wait()
        finally:
            # Always stop the watchdog, even when wait() is interrupted, so
            # its thread does not outlive this call.
            if timer:
                timer.cancel()
    return process.returncode
61
62
def _SshCall(cmd, timeout=None):
    """Runs a single SSH command and drains its output.

    - SSH returns code 0 for "Successful execution".
    - Uses communicate(), which returns only once the process has exited and
      its output pipe has been closed by every holder.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional integer, number of seconds to give the command
                 before the launched process is killed.

    Returns:
        The exit status of the process; 0 indicates that it ran successfully.
    """
    logger.info("Running command \"%s\"", cmd)
    # NOTE(review): unlike _SshLogOutput, the command is not prefixed with
    # "exec", so the timer may only kill the shell and not the ssh child --
    # confirm whether callers rely on the current behaviour.
    process = subprocess.Popen(cmd, shell=True, stdin=None,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    timer = None
    if timeout:
        # TODO: if process is killed, out error message to log.
        timer = threading.Timer(timeout, process.kill)
        timer.start()
    try:
        # Output is read (and dropped) so the child never blocks on a full
        # pipe.
        process.communicate()
    finally:
        # Always stop the watchdog, even when communicate() is interrupted,
        # so its thread does not outlive this call.
        if timer:
            timer.cancel()
    return process.returncode
88
89
def _SshLogOutput(cmd, timeout=None, show_output=False):
    """Runs a single SSH command while logging its output and processes its return code.

    The combined stdout/stderr is collected once the command finishes. It is
    printed to stderr when show_output is set or the command failed, and
    logged at the debug level otherwise.
    SSH returns error code 255 for "failed to connect", so this is interpreted as a failure in
    SSH rather than a failure on the target device and this is converted to a different exception
    type.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary and its arguments.
        timeout: Optional integer, number of seconds to give the command
                 before the process is killed.
        show_output: Boolean, True to show command output in screen.

    Raises:
        errors.DeviceConnectionError: Failed to connect to the GCE instance.
        subprocess.CalledProcessError: The process exited with an error on the instance.
    """
    # Use "exec" to let cmd to inherit the shell process, instead of having the
    # shell launch a child process which does not get killed.
    cmd = "exec " + cmd
    logger.info("Running command \"%s\"", cmd)
    # universal_newlines=True makes the captured output text on Python 3 as
    # well, so it is printed/logged as readable text rather than a bytes repr.
    process = subprocess.Popen(cmd, shell=True, stdin=None,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               universal_newlines=True)
    timer = None
    if timeout:
        # TODO: if process is killed, out error message to log.
        timer = threading.Timer(timeout, process.kill)
        timer.start()
    try:
        stdout, _ = process.communicate()
    finally:
        # Stop the watchdog as soon as the process is done (or we are
        # interrupted) so its thread does not outlive this call.
        if timer:
            timer.cancel()

    if stdout:
        if show_output or process.returncode != 0:
            print(stdout.strip(), file=sys.stderr)
        else:
            # fetch_cvd and launch_cvd can be noisy, so left at debug
            logger.debug(stdout.strip())

    if process.returncode == 255:
        raise errors.DeviceConnectionError(
            "Failed to send command to instance (%s)" % cmd)
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, cmd)
131
132
def ShellCmdWithRetry(cmd, timeout=None, show_output=False):
    """Runs a shell command on remote device, retrying on connection failure.

    The command is executed through _SshLogOutput under
    utils.RetryExceptionType, which is asked to retry on
    errors.DeviceConnectionError up to _SSH_CMD_MAX_RETRY times using
    _SSH_CMD_RETRY_SLEEP as the sleep multiplier and the default backoff
    factor, so that an unstable network gets progressively more time to
    recover between attempts.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary and its arguments.
        timeout: Optional integer, number of seconds to give.
        show_output: Boolean, True to show command output in screen.

    Raises:
        errors.DeviceConnectionError: The connection to the instance failed;
            presumably propagated once the retries are exhausted -- confirm
            against utils.RetryExceptionType.
        subprocess.CalledProcessError: Raised by _SshLogOutput when the command
            exits with any other non-zero code; not in the retried types.
    """
    # How and when to retry.
    retry_policy = {
        "exception_types": errors.DeviceConnectionError,
        "max_retries": _SSH_CMD_MAX_RETRY,
        "sleep_multiplier": _SSH_CMD_RETRY_SLEEP,
        "retry_backoff_factor": utils.DEFAULT_RETRY_BACKOFF_FACTOR,
    }
    # Remaining keyword arguments are forwarded to the functor.
    utils.RetryExceptionType(functor=_SshLogOutput,
                             cmd=cmd,
                             timeout=timeout,
                             show_output=show_output,
                             **retry_policy)
159
160
class IP(object):
    """Holds the external and internal IP addresses of an instance."""
    def __init__(self, external=None, internal=None, ip=None):
        """Init for IP.

        Args:
            external: String, external ip.
            internal: String, internal ip.
            ip: String, fallback used for whichever of external/internal
                is not set (None or empty).
        """
        # A falsy address (None or "") falls back to the shared default.
        self.external = external if external else ip
        self.internal = internal if internal else ip
173
174
class Ssh(object):
    """A class that control the remote instance via the IP address.

    Attributes:
        _ip: The address used to reach the instance; taken from the internal
             or external field of the IP object given to the constructor.
        _user: String of user login into the instance.
        _ssh_private_key_path: Path to the private key file.
        _extra_args_ssh_tunnel: String, extra args for ssh or scp.
    """
    def __init__(self, ip, user, ssh_private_key_path,
                 extra_args_ssh_tunnel=None, report_internal_ip=False):
        """Init for Ssh.

        Args:
            ip: IP object holding the external and internal addresses.
            user: String of user login into the instance.
            ssh_private_key_path: Path to the private key file.
            extra_args_ssh_tunnel: String, extra args for ssh or scp.
            report_internal_ip: Boolean, True to connect through the internal
                                address instead of the external one.
        """
        self._ip = ip.internal if report_internal_ip else ip.external
        self._user = user
        self._ssh_private_key_path = ssh_private_key_path
        self._extra_args_ssh_tunnel = extra_args_ssh_tunnel

    def Run(self, target_command, timeout=None, show_output=False):
        """Run a shell command over SSH on a remote instance.

        Example:
            The base command is
                ssh -i ~/private_key_path $extra_args -l user 1.1.1.1
            and target_command "remote command" is appended after a space.

        Args:
            target_command: String, text of command to run on the remote instance.
            timeout: Integer, the maximum time to wait for the command to respond.
            show_output: Boolean, True to show command output in screen.
        """
        full_command = self.GetBaseCmd(constants.SSH_BIN) + " " + target_command
        ShellCmdWithRetry(full_command, timeout, show_output)

    def GetBaseCmd(self, execute_bin):
        """Get a base command over SSH on a remote instance.

        Example:
            execute bin is ssh:
                ssh -i ~/private_key_path $extra_args -l user 1.1.1.1
            execute bin is scp:
                scp -i ~/private_key_path $extra_args

        Args:
            execute_bin: String, execute type, e.g. ssh or scp.

        Returns:
            String of the base connection command.

        Raises:
            errors.UnknownType: Don't support the execute bin.
        """
        base_cmd = [utils.FindExecutable(execute_bin)]
        base_cmd.append(_SSH_CMD % {"rsa_key_file": self._ssh_private_key_path})
        if self._extra_args_ssh_tunnel:
            base_cmd.append(self._extra_args_ssh_tunnel)

        if execute_bin == constants.SSH_BIN:
            # Only ssh takes the login user and host as options; scp gets
            # them as part of the remote path (see ScpPushFile/ScpPullFile).
            base_cmd.append(_SSH_IDENTITY %
                            {"login_user": self._user, "ip_addr": self._ip})
            return " ".join(base_cmd)
        if execute_bin == constants.SCP_BIN:
            return " ".join(base_cmd)

        raise errors.UnknownType("Don't support the execute bin %s." % execute_bin)

    def CheckSshConnection(self, timeout):
        """Run remote 'uptime' ssh command to check ssh connection.

        Args:
            timeout: Integer, the maximum time to wait for the command to respond.

        Raises:
            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
        """
        remote_cmd = [self.GetBaseCmd(constants.SSH_BIN)]
        remote_cmd.append("uptime")

        # Any non-zero exit status (including a kill on timeout) counts as
        # "not ready".
        if _SshCallWait(" ".join(remote_cmd), timeout) == 0:
            return
        raise errors.DeviceConnectionError(
            "Ssh isn't ready in the remote instance.")

    @utils.TimeExecute(function_description="Waiting for SSH server")
    def WaitForSsh(self, timeout=None, sleep_for_retry=_SSH_CMD_RETRY_SLEEP,
                   max_retry=_SSH_CMD_MAX_RETRY):
        """Wait until the remote instance is ready to accept commands over SSH.

        The overall timeout is split evenly across max_retry rounds; when no
        timeout is given each round gets _WAIT_FOR_SSH_MAX_TIMEOUT seconds.

        Args:
            timeout: Integer, the maximum time in seconds to wait for the
                     command to respond.
            sleep_for_retry: Integer, the sleep time in seconds for retry.
            max_retry: Integer, the maximum number of retry.

        Raises:
            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
        """
        timeout_one_round = None
        if timeout:
            if max_retry > 0:
                # True division: under Python 2 an integer timeout smaller
                # than max_retry would otherwise truncate to 0 and silently
                # fall back to the (much longer) default below.
                timeout_one_round = float(timeout) / max_retry
            else:
                # Nothing to split the budget across; avoid dividing by zero.
                timeout_one_round = timeout
        utils.RetryExceptionType(
            exception_types=errors.DeviceConnectionError,
            max_retries=max_retry,
            functor=self.CheckSshConnection,
            sleep_multiplier=sleep_for_retry,
            retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
            timeout=timeout_one_round or _WAIT_FOR_SSH_MAX_TIMEOUT)

    def ScpPushFile(self, src_file, dst_file):
        """Scp push file to remote.

        Args:
            src_file: The local source file path to be pushed.
            dst_file: The destination file path on the remote instance.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.append(src_file)
        scp_command.append("%s@%s:%s" % (self._user, self._ip, dst_file))
        ShellCmdWithRetry(" ".join(scp_command))

    def ScpPullFile(self, src_file, dst_file):
        """Scp pull file from remote.

        Args:
            src_file: The source file path on the remote instance to be pulled.
            dst_file: The local destination file path the file is pulled to.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.append("%s@%s:%s" % (self._user, self._ip, src_file))
        scp_command.append(dst_file)
        ShellCmdWithRetry(" ".join(scp_command))
306