1# Copyright 2016 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16import os
17import re
18import shutil
19import tempfile
20import threading
21import time
22import uuid
23
24from acts import logger
25from acts.controllers.utils_lib import host_utils
26from acts.controllers.utils_lib.ssh import formatter
27from acts.libs.proc import job
28
29
class Error(Exception):
    """Raised when an ssh operation fails (e.g. the connection cannot be made)."""
32
33
class CommandError(Exception):
    """Raised when ssh itself worked but the remote command failed.

    Attributes:
        result: The result object of the failing ssh command; it is expected
            to expose ``command``, ``stdout`` and ``stderr``.
    """

    def __init__(self, result):
        """
        Args:
            result: The result of the ssh command that failed.
        """
        self.result = result

    def __str__(self):
        res = self.result
        pieces = [
            'cmd: %s' % (res.command,),
            'stdout: %s' % (res.stdout,),
            'stderr: %s' % (res.stderr,),
        ]
        return '\n'.join(pieces)
52
53
54_Tunnel = collections.namedtuple('_Tunnel',
55                                 ['local_port', 'remote_port', 'proc'])
56
57
class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent (master)
    connection when a command is run. If the persistent connection fails it
    will fall back to a normal ssh connection.
    """

    @property
    def socket_path(self):
        """Returns: The os path to the master socket file."""
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection. This class
                reads ``hostname`` and ``connect_timeout`` from it and passes
                it to the ssh formatter.
        """
        self._settings = settings
        self._formatter = formatter.SshFormatter()
        self._lock = threading.Lock()
        self._master_ssh_proc = None
        self._master_ssh_tempdir = None
        self._tunnels = list()

        def log_line(msg):
            return '[SshConnection | %s] %s' % (self._settings.hostname, msg)

        self.log = logger.create_logger(log_line)

    def __del__(self):
        # If __init__ raised part-way, the attributes close() relies on may not
        # exist; raising from a finalizer would only produce noise.
        if hasattr(self, '_tunnels'):
            self.close()

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. A previously started master whose process has exited, or
        whose socket file has disappeared, is torn down and restarted.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
                be made.

        Raises:
            Error: When setting up the master ssh connection fails, either
                because it timed out or because the master ssh process exited
                before creating its control socket.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                socket_path = self.socket_path
                if (not os.path.exists(socket_path) or
                        self._master_ssh_proc.poll() is not None):
                    self.log.debug('Master ssh connection to %s is down.',
                                   self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is None:
                # Create a shared socket in a temp location.
                self._master_ssh_tempdir = tempfile.mkdtemp(
                    prefix='ssh-master')

                # Setup flags and options for running the master ssh
                # -N: Do not execute a remote command.
                # ControlMaster: Spawn a master connection.
                # ControlPath: The master connection socket path.
                # BatchMode: Disable interactive password/passphrase prompts.
                extra_flags = {'-N': None}
                extra_options = {
                    'ControlMaster': True,
                    'ControlPath': self.socket_path,
                    'BatchMode': True
                }

                # Construct the command and start it.
                master_cmd = self._formatter.format_ssh_local_command(
                    self._settings,
                    extra_flags=extra_flags,
                    extra_options=extra_options)
                self.log.info('Starting master ssh connection.')
                try:
                    self._master_ssh_proc = job.run_async(master_cmd)
                except Exception:
                    # Don't leak the temp dir created above.
                    self._cleanup_master_ssh()
                    raise

                end_time = time.time() + timeout_seconds

                while time.time() < end_time:
                    if os.path.exists(self.socket_path):
                        break
                    if self._master_ssh_proc.poll() is not None:
                        # The master exited without ever creating its socket;
                        # waiting out the rest of the timeout is pointless.
                        self._cleanup_master_ssh()
                        raise Error('Master ssh connection exited before '
                                    'creating its socket.')
                    time.sleep(.2)
                else:
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection timed out.')

    def run(self,
            command,
            timeout=3600,
            ignore_status=False,
            env=None,
            io_encoding='utf-8',
            attempts=2):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. It is interpolated into
                     a shell command line with ``%s``.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                           subprocess.  Note that if you do ignore status codes,
                           you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
                         NOTE(review): currently unused; the encoding reported
                         by job.run is used instead.
            attempts: Number of attempts before giving up on unknown failures.
                      When 0, nothing is run and None is returned.

        Returns:
            A job.Result containing the results of the ssh command, with the
            internal connection-marker line removed from stdout.

        Raises:
            job.TimeoutError: When the remote command took to long to execute.
            Error: When the ssh connection failed to be created.
            job.Error: Ssh worked, but the command exited with a non-zero
                       status and ignore_status is False.
        """
        if attempts == 0:
            return None
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            self.log.warning('Failed to create master ssh connection, using '
                             'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # A unique marker echoed before the real command; seeing it proves the
        # ssh connection succeeded, so the exit status belongs to the command.
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(
                terminal_command, ignore_status=True, timeout=timeout)

            real_output = self._strip_connected_marker(result.stdout,
                                                       identifier)
            if real_output is not None:
                result = job.Result(
                    command=result.command,
                    stdout=real_output.encode(result._encoding),
                    stderr=result._raw_stderr,
                    exit_status=result.exit_status,
                    duration=result.duration,
                    did_timeout=result.did_timeout,
                    encoding=result._encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if not had_dns_failure:
                break
            dns_retry_count -= 1
            if not dns_retry_count:
                raise Error('DNS failed to find host.', result)
            self.log.debug('Failed to connect to host, retrying...')

        # ssh may emit CRLF line endings (e.g. with a pty), so accept an
        # optional carriage return before the end of the line.
        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r?$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        if 'Permission denied' in error_string:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        self.log.error('An unknown error has occurred. Job result: %s', result)
        ping_output = job.run(
            'ping %s -c 3 -w 1' % self._settings.hostname, ignore_status=True)
        self.log.error('Ping result: %s', ping_output)
        if attempts > 1:
            # Start from a fresh master connection and hand back the retry's
            # outcome (previously it was discarded and Error raised anyway).
            self._cleanup_master_ssh()
            return self.run(command, timeout, ignore_status, env, io_encoding,
                            attempts - 1)
        raise Error('The job failed for unknown reasons.', result)

    @staticmethod
    def _strip_connected_marker(output, identifier):
        """Returns the command output that follows the CONNECTED marker line.

        Args:
            output: str stdout of the ssh invocation.
            identifier: str unique id that was echoed as the marker.

        Returns:
            Everything after the line containing the marker (anything printed
            before it, such as a login banner, is dropped), '' when nothing
            follows the marker, or None when the marker is absent, meaning the
            remote command never started.
        """
        match = re.search('^CONNECTED: %s' % re.escape(identifier), output,
                          flags=re.MULTILINE)
        if match is None:
            return None
        newline = output.find('\n', match.end())
        if newline == -1:
            return ''
        return output[newline + 1:]

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.

        Args:
            command: The command to execute over ssh. It is interpolated into
                     a shell command line with ``%s``.
            env: A dictionary of environment variables to setup on the remote
                 host.

        Returns:
            The result of the command to launch the background job; its stdout
            is the remote pid printed by ``echo -n $!``.

        Raises:
            Same as run(): job.TimeoutError, Error, or job.Error.
        """
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % command
        result = self.run(command, env=env)
        return result

    def close(self):
        """Clean up open connections to remote host."""
        self._cleanup_master_ssh()
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """Release all resources (process, temporary directory) used by an
        active master SSH connection.

        Safe to call when no master connection exists.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            self.log.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket.
        if self._master_ssh_tempdir is not None:
            self.log.debug('Cleaning master_ssh_tempdir.')
            tempdir = self._master_ssh_tempdir
            self._master_ssh_tempdir = None
            # The directory may already be gone; that must not turn cleanup
            # (and therefore setup_master_ssh/close) into a crash.
            shutil.rmtree(tempdir, ignore_errors=True)

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                        port.

        Returns:
            The local port that forwards to the remote port.
            NOTE(review): when local_port is given and a tunnel to the same
            remote port already exists, that tunnel's local port is returned,
            which may differ from the requested local_port.
        """
        if local_port is None:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic commands
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        self.log.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        self.log.debug('Started ssh tunnel, local = %d'
                       ' remote = %d, pid = %d', local_port, port,
                       tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                        host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        idx = None
        for i, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                idx = i
                break
        if idx is not None:
            tunnel = self._tunnels.pop(idx)
            tunnel.proc.kill()
            tunnel.proc.wait()
            return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path):
        """Send a file from the local host to the remote host.

        NOTE(review): only the user/host from the settings are passed to scp,
        and both paths are interpolated unquoted into the command line.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
        """
        # TODO: This may belong somewhere else: b/32572515
        user_host = self._formatter.format_host_name(self._settings)
        job.run('scp %s %s:%s' % (local_path, user_host, remote_path))

    def find_free_port(self, interface_name='localhost'):
        """Find an unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free.

        Args:
            interface_name: string name of interface to check whether a
                            port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port
421