# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import re
import shutil
import tempfile
import threading
import time
import uuid

from acts import logger
from acts.controllers.utils_lib import host_utils
from acts.controllers.utils_lib.ssh import formatter
from acts.libs.proc import job


class Error(Exception):
    """An error occurred during an ssh operation."""


class CommandError(Exception):
    """An error occurred with the command.

    Attributes:
        result: The results of the ssh command that had the error.
    """

    def __init__(self, result):
        """
        Args:
            result: The result of the ssh command that created the problem.
        """
        self.result = result

    def __str__(self):
        return 'cmd: %s\nstdout: %s\nstderr: %s' % (self.result.command,
                                                    self.result.stdout,
                                                    self.result.stderr)


# Bookkeeping for an open port-forwarding tunnel: the local port, the remote
# port it forwards to, and the ssh process (from job.run_async) backing it.
_Tunnel = collections.namedtuple('_Tunnel',
                                 ['local_port', 'remote_port', 'proc'])


class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent connection when
    a command is run. If the persistent connection fails it will attempt
    to connect normally.
    """

    @property
    def socket_path(self):
        """Returns: The os path to the master socket file."""
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection.
        """
        self._settings = settings
        # Formats ssh command lines for use with the background job.
        self._formatter = formatter.SshFormatter()
        # Guards creation / teardown of the master connection in
        # setup_master_ssh.
        self._lock = threading.Lock()
        self._master_ssh_proc = None
        self._master_ssh_tempdir = None
        self._tunnels = list()

        def log_line(msg):
            return '[SshConnection | %s] %s' % (self._settings.hostname, msg)

        self.log = logger.create_logger(log_line)

    def __del__(self):
        self.close()

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. If a previous master connection exists but its socket file
        is gone or its process has exited, it is torn down and recreated.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
                be made.

        Raises:
            Error: When setting up the master ssh connection fails.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                socket_path = self.socket_path
                if (not os.path.exists(socket_path) or
                        self._master_ssh_proc.poll() is not None):
                    self.log.debug('Master ssh connection to %s is down.',
                                   self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is None:
                # Create a shared socket in a temp location.
                self._master_ssh_tempdir = tempfile.mkdtemp(
                    prefix='ssh-master')

                # Setup flags and options for running the master ssh
                # -N: Do not execute a remote command.
                # ControlMaster: Spawn a master connection.
                # ControlPath: The master connection socket path.
                extra_flags = {'-N': None}
                extra_options = {
                    'ControlMaster': True,
                    'ControlPath': self.socket_path,
                    'BatchMode': True
                }

                # Construct the command and start it.
                master_cmd = self._formatter.format_ssh_local_command(
                    self._settings,
                    extra_flags=extra_flags,
                    extra_options=extra_options)
                self.log.info('Starting master ssh connection.')
                self._master_ssh_proc = job.run_async(master_cmd)

                end_time = time.time() + timeout_seconds

                # The master is considered up once its control socket
                # appears; poll for it until the deadline.
                while time.time() < end_time:
                    if os.path.exists(self.socket_path):
                        break
                    time.sleep(.2)
                else:
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection timed out.')

    def run(self,
            command,
            timeout=3600,
            ignore_status=False,
            env=None,
            io_encoding='utf-8',
            attempts=2):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. Can be either a string
                     or a list.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                           subprocess. Note that if you do ignore status codes,
                           you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output. Currently
                         unused; the encoding reported by job.run is used.
            attempts: Number of attempts before giving up on command failures.

        Returns:
            A job.Result containing the results of the ssh command, or None
            if attempts is 0.

        Raises:
            job.TimeoutError: When the remote command took too long to execute.
            job.Error: Ssh worked, but the command exited with a non-zero
                status and ignore_status is False.
            Error: When the ssh connection failed to be created, or the
                command failed for unknown reasons on every attempt.
        """
        if attempts == 0:
            return None
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            self.log.warning('Failed to create master ssh connection, using '
                             'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # A unique marker is echoed before the real command so that a failed
        # command can be told apart from a failed ssh connection.
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(
                terminal_command, ignore_status=True, timeout=timeout)
            output = result.stdout

            # Check for a connected message to prevent false negatives.
            valid_connection = re.search(
                '^CONNECTED: %s' % identifier, output, flags=re.MULTILINE)
            if valid_connection:
                # Remove everything up to and including the line that holds
                # the connect message. Anything printed before the marker
                # (e.g. remote shell start-up output) is not command output,
                # and the marker need not be on the first line.
                line_end = output.find('\n', valid_connection.end())
                if line_end == -1:
                    line_end = len(output)
                real_output = output[line_end + 1:].encode(result._encoding)

                result = job.Result(
                    command=result.command,
                    stdout=real_output,
                    stderr=result._raw_stderr,
                    exit_status=result.exit_status,
                    duration=result.duration,
                    did_timeout=result.did_timeout,
                    encoding=result._encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            # Exit status 255 is ssh's own failure code; retry transient DNS
            # lookup failures a limited number of times.
            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if had_dns_failure:
                dns_retry_count -= 1
                if not dns_retry_count:
                    raise Error('DNS failed to find host.', result)
                self.log.debug('Failed to connect to host, retrying...')
            else:
                break

        # ssh may or may not terminate this line with a carriage return.
        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r?$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        permission_denied = 'Permission denied' in error_string
        if permission_denied:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        self.log.error('An unknown error has occurred. Job result: %s', result)
        ping_output = job.run(
            'ping %s -c 3 -w 1' % self._settings.hostname, ignore_status=True)
        self.log.error('Ping result: %s', ping_output)
        if attempts > 1:
            # Drop a possibly broken master connection and retry; the
            # retry's outcome (result or exception) is what the caller sees.
            self._cleanup_master_ssh()
            return self.run(command, timeout, ignore_status, env, io_encoding,
                            attempts - 1)
        raise Error('The job failed for unknown reasons.', result)

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.

        Args:
            command: The command to execute over ssh. Can be either a string
                     or a list.
            env: A dictionary of environment variables to setup on the remote
                 host.

        Returns:
            The result of the command to launch the background job; its stdout
            holds the remote PID printed by `echo -n $!`.

        Raises:
            job.TimeoutError: When the launch command took too long to execute.
            job.Error: When the launch command exited with a non-zero status.
            Error: When the ssh connection failed to be created.
        """
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % command
        result = self.run(command, env=env)
        return result

    def close(self):
        """Clean up open connections to remote host."""
        self._cleanup_master_ssh()
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            self.log.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket. This is
        # best-effort cleanup (also reached from __del__), so a directory
        # that is already gone must not raise.
        if self._master_ssh_tempdir is not None:
            self.log.debug('Cleaning master_ssh_tempdir.')
            shutil.rmtree(self._master_ssh_tempdir, ignore_errors=True)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                        port.

        Returns:
            The local port number the tunnel listens on.
        """
        if local_port is None:
            local_port = host_utils.get_available_host_port()
        else:
            # NOTE(review): when local_port is given and a tunnel to the same
            # remote port already exists, that tunnel's local port is returned
            # even if it differs from the requested one.
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic commands
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        self.log.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        self.log.debug('Started ssh tunnel, local = %d'
                       ' remote = %d, pid = %d', local_port, port,
                       tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                        host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        idx = None
        for i, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                idx = i
                break
        if idx is not None:
            tunnel = self._tunnels.pop(idx)
            tunnel.proc.kill()
            tunnel.proc.wait()
            return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path):
        """Send a file from the local host to the remote host.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
        """
        # TODO: This may belong somewhere else: b/32572515
        user_host = self._formatter.format_host_name(self._settings)
        job.run('scp %s %s:%s' % (local_path, user_host, remote_path))

    def find_free_port(self, interface_name='localhost'):
        """Find an unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free.

        Args:
            interface_name: string name of interface to check whether a
                            port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port