1# Copyright 2016 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import collections
import os
import re
import shlex
import shutil
import tempfile
import threading
import time
import uuid

from acts import logger
from acts.controllers.utils_lib import host_utils
from acts.controllers.utils_lib.ssh import formatter
from acts.libs.proc import job
28
29
class Error(Exception):
    """Raised when an ssh operation itself fails (e.g. no connection)."""
    pass
32
33
class CommandError(Exception):
    """An error occurred with the command.

    Attributes:
        result: The results of the ssh command that had the error.
    """

    def __init__(self, result):
        """
        Args:
            result: The result of the ssh command that created the problem.
                    Its `command`, `stdout` and `stderr` attributes are read
                    by __str__.
        """
        # Forward to Exception so `args` is populated. Without this the
        # exception cannot be unpickled (Exception.__reduce__ re-calls the
        # class with `args`, which would be empty) and repr() loses detail.
        super(CommandError, self).__init__(result)
        self.result = result

    def __str__(self):
        return 'cmd: %s\nstdout: %s\nstderr: %s' % (
            self.result.command, self.result.stdout, self.result.stderr)
51
52
# Bookkeeping record for one forwarded port: the local port, the remote port
# it forwards to, and the local ssh process that implements the forward.
_Tunnel = collections.namedtuple('_Tunnel', 'local_port remote_port proc')
55
56
class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent (master)
    connection when a command is run. If the persistent connection fails it
    will attempt to connect normally.
    """

    @property
    def socket_path(self):
        """Returns: The os path to the master socket file.

        Only meaningful after setup_master_ssh() has created the temporary
        directory; before that _master_ssh_tempdir is None.
        """
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection. This class
                      reads `hostname` and `connect_timeout` from it directly
                      and passes it to the ssh formatter.
        """
        self._settings = settings
        self._formatter = formatter.SshFormatter()
        self._lock = threading.Lock()
        self._master_ssh_proc = None
        self._master_ssh_tempdir = None
        self._tunnels = list()

        def log_line(msg):
            return '[SshConnection | %s] %s' % (self._settings.hostname, msg)

        self.log = logger.create_logger(log_line)

    def __enter__(self):
        return self

    def __exit__(self, _, __, ___):
        self.close()

    def __del__(self):
        # close() tolerates partially-initialized instances, so this is safe
        # even when __init__ raised.
        self.close()

    @staticmethod
    def _to_shell_string(command):
        """Returns `command` as a single shell command string.

        Strings are returned unchanged. A list/tuple is treated as an argv:
        each element is shell-quoted so arguments containing spaces or shell
        metacharacters reach the remote command intact.
        """
        if isinstance(command, (list, tuple)):
            return ' '.join(shlex.quote(str(part)) for part in command)
        return command

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started, or restarts it if its socket disappeared or its process
        exited.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
            be made.

        Raises:
            Error: When setting up the master ssh connection fails.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                # Tear down a stale master so a fresh one is started below.
                if (not os.path.exists(self.socket_path)
                        or self._master_ssh_proc.poll() is not None):
                    self.log.debug('Master ssh connection to %s is down.',
                                   self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is not None:
                return  # The existing master connection is healthy.

            # Create a shared socket in a temp location.
            self._master_ssh_tempdir = tempfile.mkdtemp(prefix='ssh-master')

            # Setup flags and options for running the master ssh
            # -N: Do not execute a remote command.
            # ControlMaster: Spawn a master connection.
            # ControlPath: The master connection socket path.
            extra_flags = {'-N': None}
            extra_options = {
                'ControlMaster': True,
                'ControlPath': self.socket_path,
                'BatchMode': True
            }

            # Construct the command and start it.
            master_cmd = self._formatter.format_ssh_local_command(
                self._settings,
                extra_flags=extra_flags,
                extra_options=extra_options)
            self.log.info('Starting master ssh connection.')
            try:
                self._master_ssh_proc = job.run_async(master_cmd)
            except Exception:
                # Don't leak the temp dir created above.
                self._cleanup_master_ssh()
                raise

            end_time = time.time() + timeout_seconds
            while not os.path.exists(self.socket_path):
                # If ssh already exited (auth failure, refused, ...) the
                # socket will never appear; fail now instead of waiting out
                # the whole timeout.
                if self._master_ssh_proc.poll() is not None:
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection exited before the '
                                'socket was created.')
                if time.time() >= end_time:
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection timed out.')
                time.sleep(.2)

    def run(self,
            command,
            timeout=3600,
            ignore_status=False,
            env=None,
            io_encoding='utf-8',
            attempts=2):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. Either a shell command
                     string, or a list/tuple of arguments which is
                     shell-quoted and joined.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                           subprocess.  Note that if you do ignore status codes,
                           you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
            attempts: Number of times to try when the ssh connection fails
                      for an unrecognized reason. 0 returns None without
                      running anything.

        Returns:
            A job.Result containing the results of the ssh command, with the
            internal connection-marker line removed from stdout.

        Raises:
            job.TimeoutError: When the remote command took too long to
                              execute (presumably raised by job.run).
            Error: When the ssh connection failed to be created.
            job.Error: Ssh worked, but the command exited with a non-zero
                       status and ignore_status is False.
        """
        if attempts == 0:
            return None
        if env is None:
            env = {}
        command = self._to_shell_string(command)

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            self.log.warning('Failed to create master ssh connection, using '
                             'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # Echo a unique marker before the real command. Its presence in stdout
        # proves ssh connected, which separates "ssh failed" from "the
        # remote command failed".
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(
                terminal_command,
                ignore_status=True,
                timeout=timeout,
                io_encoding=io_encoding)
            output = result.stdout

            # Check for a connected message to prevent false negatives.
            valid_connection = re.search(
                '^CONNECTED: %s' % identifier, output, flags=re.MULTILINE)
            if valid_connection:
                # Remove the first line that contains the connect message.
                line_index = output.find('\n') + 1
                if line_index == 0:
                    line_index = len(output)
                real_output = output[line_index:].encode(io_encoding)

                result = job.Result(
                    command=result.command,
                    stdout=real_output,
                    stderr=result._raw_stderr,
                    exit_status=result.exit_status,
                    duration=result.duration,
                    did_timeout=result.did_timeout,
                    encoding=io_encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if had_dns_failure:
                dns_retry_count -= 1
                if not dns_retry_count:
                    raise Error('DNS failed to find host.', result)
                self.log.debug('Failed to connect to host, retrying...')
            else:
                break

        # OpenSSH usually terminates stderr messages with "\r\n", but don't
        # depend on the carriage return being present.
        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r?$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        permission_denied = 'Permission denied' in error_string
        if permission_denied:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        self.log.error('An unknown error has occurred. Job result: %s' % result)
        ping_output = job.run(
            'ping %s -c 3 -w 1' % self._settings.hostname, ignore_status=True)
        self.log.error('Ping result: %s' % ping_output)
        if attempts > 1:
            # Drop a possibly broken master connection and try again. The
            # retry's outcome (result or exception) is what the caller gets.
            self._cleanup_master_ssh()
            return self.run(command, timeout, ignore_status, env, io_encoding,
                            attempts - 1)
        raise Error('The job failed for unknown reasons.', result)

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.
        The command's stdin/stdout/stderr are redirected to /dev/null.

        Args:
            command: The command to execute over ssh. Either a shell command
                     string, or a list/tuple of arguments which is
                     shell-quoted and joined.
            env: A dictionary of environment variables to setup on the remote
                 host.

        Returns:
            The job.Result of the launch command. Its stdout holds the remote
            PID of the background job (printed by `echo -n $!`).

        Raises:
            Error: When the ssh connection failed (see run()).
            job.Error: When the launch command itself exits non-zero.
        """
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % (
            self._to_shell_string(command))
        return self.run(command, env=env)

    def close(self):
        """Clean up open connections to remote host.

        Safe to call more than once, and safe on an instance whose __init__
        did not complete (as can happen when __del__ runs).
        """
        self._cleanup_master_ssh()
        while getattr(self, '_tunnels', None):
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # getattr() keeps this usable from __del__ if __init__ failed before
        # these attributes were assigned.
        # If a master SSH connection is running, kill it.
        if getattr(self, '_master_ssh_proc', None) is not None:
            self.log.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket. Best
        # effort: a directory that is already gone must not abort run()'s
        # retry path or __del__.
        if getattr(self, '_master_ssh_tempdir', None) is not None:
            self.log.debug('Cleaning master_ssh_tempdir.')
            shutil.rmtree(self._master_ssh_tempdir, ignore_errors=True)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                        port.

        Returns:
            The local port number of the tunnel.

        NOTE(review): when local_port is given and a tunnel to the same remote
        port already exists, that tunnel's local port is returned even if it
        differs from local_port; confirm this is intended.
        """
        if not local_port:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic commands
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        self.log.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        self.log.debug('Started ssh tunnel, local = %d remote = %d, pid = %d',
                       local_port, port, tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                        host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        for idx, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                # Safe to mutate the list: we return immediately.
                del self._tunnels[idx]
                tunnel.proc.kill()
                tunnel.proc.wait()
                return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path, ignore_status=False):
        """Send a file from the local host to the remote host.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.

        Returns:
            The result of the scp job.
        """
        # TODO: This may belong somewhere else: b/32572515
        # NOTE(review): paths are interpolated unquoted into a shell string,
        # so paths containing spaces will not work.
        user_host = self._formatter.format_host_name(self._settings)
        return job.run(
            'scp %s %s:%s' % (local_path, user_host, remote_path),
            ignore_status=ignore_status)

    def pull_file(self, local_path, remote_path, ignore_status=False):
        """Copy a file from the remote host to the local host.

        Args:
            local_path: string path of file to recv on local host
            remote_path: string path to copy file from on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.

        Returns:
            The result of the scp job.
        """
        # NOTE(review): paths are interpolated unquoted into a shell string,
        # so paths containing spaces will not work.
        user_host = self._formatter.format_host_name(self._settings)
        return job.run(
            'scp %s:%s %s' % (user_host, remote_path, local_path),
            ignore_status=ignore_status)

    def find_free_port(self, interface_name='localhost'):
        """Find a unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free.

        Args:
            interface_name: string name of interface to check whether a
                            port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        # Binding to port 0 makes the remote kernel choose a free port.
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port
445