1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""A wrapper around ssh for common operations on a CrOS-based device"""
5import logging
6import os
7import re
8import shutil
9import stat
10import subprocess
11import tempfile
12
# Some developers' workflow includes running the Chrome process from
# /usr/local/... instead of the default location. We have to check for both
# paths in order to support this workflow.
# Both patterns are anchored at the start of the string and end with a space,
# so they only match a command line whose executable path is followed by a
# space.
_CHROME_PROCESS_REGEX = [re.compile(r'^/opt/google/chrome/chrome '),
                         re.compile(r'^/usr/local/?.*/chrome/chrome ')]
18
19
def RunCmd(args, cwd=None, quiet=False):
  """Opens a subprocess to execute a program and returns its return value.

  The child's stdin, stdout and stderr are all redirected to os.devnull.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, the command line is not written to the debug log.

  Returns:
    Return code from the command execution.

  Raises:
    OSError: If the program could not be started.
  """
  if not quiet:
    # |args| may be a single string; joining a string would insert a space
    # between every character, so only join real sequences.
    cmd_line = args if isinstance(args, str) else ' '.join(args)
    logging.debug('%s %s', cmd_line, cwd or '')
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args,
                         cwd=cwd,
                         stdout=devnull,
                         stderr=devnull,
                         stdin=devnull,
                         shell=False)
    return p.wait()
42
43
def GetAllCmdOutput(args, cwd=None, quiet=False):
  """Open a subprocess to execute a program and returns its output.

  The child's stdin is redirected from os.devnull.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, neither the command line nor the captured output is
      written to the debug log.

  Returns:
    A (stdout, stderr) tuple holding everything the command wrote to each
    stream.
  """
  if not quiet:
    # |args| may be a single string; joining a string would insert a space
    # between every character, so only join real sequences.
    cmd_line = args if isinstance(args, str) else ' '.join(args)
    logging.debug('%s %s', cmd_line, cwd or '')
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args,
                         cwd=cwd,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         stdin=devnull)
    stdout, stderr = p.communicate()
    if not quiet:
      logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
    return stdout, stderr
69
70
def HasSSH():
  """Reports whether both the `ssh' and `scp' programs can be launched.

  Returns:
    True if both programs could be started, False if starting either of them
    raised OSError.
  """
  try:
    for tool in ('ssh', 'scp'):
      RunCmd([tool], quiet=True)
  except OSError:
    logging.debug("HasSSH()->False")
    return False
  logging.debug("HasSSH()->True")
  return True
80
81
class LoginException(Exception):
  """Raised by CrOSInterface.TryLogin() when logging into the device fails."""
84
85
class KeylessLoginRequiredException(LoginException):
  """Raised by TryLogin() when ssh reports a public-key permission denial."""
88
89
class DNSFailureException(LoginException):
  """Raised by TryLogin() when ssh cannot resolve the device's hostname."""
92
93
94class CrOSInterface(object):
95
96  def __init__(self, hostname=None, ssh_port=None, ssh_identity=None):
97    self._hostname = hostname
98    self._ssh_port = ssh_port
99
100    # List of ports generated from GetRemotePort() that may not be in use yet.
101    self._reserved_ports = []
102
103    if self.local:
104      return
105
106    self._ssh_identity = None
107    self._ssh_args = ['-o ConnectTimeout=5', '-o StrictHostKeyChecking=no',
108                      '-o KbdInteractiveAuthentication=no',
109                      '-o PreferredAuthentications=publickey',
110                      '-o UserKnownHostsFile=/dev/null', '-o ControlMaster=no']
111
112    if ssh_identity:
113      self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))
114      os.chmod(self._ssh_identity, stat.S_IREAD)
115
116    # Establish master SSH connection using ControlPersist.
117    # Since only one test will be run on a remote host at a time,
118    # the control socket filename can be telemetry@hostname.
119    self._ssh_control_file = '/tmp/' + 'telemetry' + '@' + hostname
120    with open(os.devnull, 'w') as devnull:
121      subprocess.call(
122          self.FormSSHCommandLine(['-M', '-o ControlPersist=yes']),
123          stdin=devnull,
124          stdout=devnull,
125          stderr=devnull)
126
  def __enter__(self):
    """Context-manager entry; returns this interface unchanged."""
    return self
129
  def __exit__(self, *args):
    """Context-manager exit; calls CloseConnection() and ignores |args|."""
    self.CloseConnection()
132
  @property
  def local(self):
    """True when no hostname was given (None or empty string)."""
    return not self._hostname
136
  @property
  def hostname(self):
    """The hostname passed to the constructor."""
    return self._hostname
140
  @property
  def ssh_port(self):
    """The ssh port passed to the constructor."""
    return self._ssh_port
144
145  def FormSSHCommandLine(self, args, extra_ssh_args=None):
146    """Constructs a subprocess-suitable command line for `ssh'.
147    """
148    if self.local:
149      # We run the command through the shell locally for consistency with
150      # how commands are run through SSH (crbug.com/239161). This work
151      # around will be unnecessary once we implement a persistent SSH
152      # connection to run remote commands (crbug.com/239607).
153      return ['sh', '-c', " ".join(args)]
154
155    full_args = ['ssh', '-o ForwardX11=no', '-o ForwardX11Trusted=no', '-n',
156                 '-S', self._ssh_control_file] + self._ssh_args
157    if self._ssh_identity is not None:
158      full_args.extend(['-i', self._ssh_identity])
159    if extra_ssh_args:
160      full_args.extend(extra_ssh_args)
161    full_args.append('root@%s' % self._hostname)
162    full_args.append('-p%d' % self._ssh_port)
163    full_args.extend(args)
164    return full_args
165
166  def _FormSCPCommandLine(self, src, dst, extra_scp_args=None):
167    """Constructs a subprocess-suitable command line for `scp'.
168
169    Note: this function is not designed to work with IPv6 addresses, which need
170    to have their addresses enclosed in brackets and a '-6' flag supplied
171    in order to be properly parsed by `scp'.
172    """
173    assert not self.local, "Cannot use SCP on local target."
174
175    args = ['scp', '-P', str(self._ssh_port)] + self._ssh_args
176    if self._ssh_identity:
177      args.extend(['-i', self._ssh_identity])
178    if extra_scp_args:
179      args.extend(extra_scp_args)
180    args += [src, dst]
181    return args
182
183  def _FormSCPToRemote(self,
184                       source,
185                       remote_dest,
186                       extra_scp_args=None,
187                       user='root'):
188    return self._FormSCPCommandLine(source,
189                                    '%s@%s:%s' % (user, self._hostname,
190                                                  remote_dest),
191                                    extra_scp_args=extra_scp_args)
192
193  def _FormSCPFromRemote(self,
194                         remote_source,
195                         dest,
196                         extra_scp_args=None,
197                         user='root'):
198    return self._FormSCPCommandLine('%s@%s:%s' % (user, self._hostname,
199                                                  remote_source),
200                                    dest,
201                                    extra_scp_args=extra_scp_args)
202
203  def _RemoveSSHWarnings(self, toClean):
204    """Removes specific ssh warning lines from a string.
205
206    Args:
207      toClean: A string that may be containing multiple lines.
208
209    Returns:
210      A copy of toClean with all the Warning lines removed.
211    """
212    # Remove the Warning about connecting to a new host for the first time.
213    return re.sub(
214        r'Warning: Permanently added [^\n]* to the list of known hosts.\s\n',
215        '', toClean)
216
217  def RunCmdOnDevice(self, args, cwd=None, quiet=False):
218    stdout, stderr = GetAllCmdOutput(
219        self.FormSSHCommandLine(args),
220        cwd,
221        quiet=quiet)
222    # The initial login will add the host to the hosts file but will also print
223    # a warning to stderr that we need to remove.
224    stderr = self._RemoveSSHWarnings(stderr)
225    return stdout, stderr
226
227  def TryLogin(self):
228    logging.debug('TryLogin()')
229    assert not self.local
230    stdout, stderr = self.RunCmdOnDevice(['echo', '$USER'], quiet=True)
231    if stderr != '':
232      if 'Host key verification failed' in stderr:
233        raise LoginException(('%s host key verification failed. ' +
234                              'SSH to it manually to fix connectivity.') %
235                             self._hostname)
236      if 'Operation timed out' in stderr:
237        raise LoginException('Timed out while logging into %s' % self._hostname)
238      if 'UNPROTECTED PRIVATE KEY FILE!' in stderr:
239        raise LoginException('Permissions for %s are too open. To fix this,\n'
240                             'chmod 600 %s' % (self._ssh_identity,
241                                               self._ssh_identity))
242      if 'Permission denied (publickey,keyboard-interactive)' in stderr:
243        raise KeylessLoginRequiredException('Need to set up ssh auth for %s' %
244                                            self._hostname)
245      if 'Could not resolve hostname' in stderr:
246        raise DNSFailureException('Unable to resolve the hostname for: %s' %
247                                  self._hostname)
248      raise LoginException('While logging into %s, got %s' % (self._hostname,
249                                                              stderr))
250    if stdout != 'root\n':
251      raise LoginException('Logged into %s, expected $USER=root, but got %s.' %
252                           (self._hostname, stdout))
253
254  def FileExistsOnDevice(self, file_name):
255    if self.local:
256      return os.path.exists(file_name)
257
258    stdout, stderr = self.RunCmdOnDevice(
259        [
260            'if', 'test', '-e', file_name, ';', 'then', 'echo', '1', ';', 'fi'
261        ],
262        quiet=True)
263    if stderr != '':
264      if "Connection timed out" in stderr:
265        raise OSError('Machine wasn\'t responding to ssh: %s' % stderr)
266      raise OSError('Unexpected error: %s' % stderr)
267    exists = stdout == '1\n'
268    logging.debug("FileExistsOnDevice(<text>, %s)->%s" % (file_name, exists))
269    return exists
270
271  def PushFile(self, filename, remote_filename):
272    if self.local:
273      args = ['cp', '-r', filename, remote_filename]
274      stdout, stderr = GetAllCmdOutput(args, quiet=True)
275      if stderr != '':
276        raise OSError('No such file or directory %s' % stderr)
277      return
278
279    args = self._FormSCPToRemote(
280        os.path.abspath(filename),
281        remote_filename,
282        extra_scp_args=['-r'])
283
284    stdout, stderr = GetAllCmdOutput(args, quiet=True)
285    stderr = self._RemoveSSHWarnings(stderr)
286    if stderr != '':
287      raise OSError('No such file or directory %s' % stderr)
288
289  def PushContents(self, text, remote_filename):
290    logging.debug("PushContents(<text>, %s)" % remote_filename)
291    with tempfile.NamedTemporaryFile() as f:
292      f.write(text)
293      f.flush()
294      self.PushFile(f.name, remote_filename)
295
296  def GetFile(self, filename, destfile=None):
297    """Copies a local file |filename| to |destfile| on the device.
298
299    Args:
300      filename: The name of the local source file.
301      destfile: The name of the file to copy to, and if it is not specified
302        then it is the basename of the source file.
303
304    """
305    logging.debug("GetFile(%s, %s)" % (filename, destfile))
306    if self.local:
307      if destfile is not None and destfile != filename:
308        shutil.copyfile(filename, destfile)
309      return
310
311    if destfile is None:
312      destfile = os.path.basename(filename)
313    args = self._FormSCPFromRemote(filename, os.path.abspath(destfile))
314
315    stdout, stderr = GetAllCmdOutput(args, quiet=True)
316    stderr = self._RemoveSSHWarnings(stderr)
317    if stderr != '':
318      raise OSError('No such file or directory %s' % stderr)
319
320  def GetFileContents(self, filename):
321    """Get the contents of a file on the device.
322
323    Args:
324      filename: The name of the file on the device.
325
326    Returns:
327      A string containing the contents of the file.
328    """
329    # TODO: handle the self.local case
330    assert not self.local
331    t = tempfile.NamedTemporaryFile()
332    self.GetFile(filename, t.name)
333    with open(t.name, 'r') as f2:
334      res = f2.read()
335      logging.debug("GetFileContents(%s)->%s" % (filename, res))
336      f2.close()
337      return res
338
339  def ListProcesses(self):
340    """Returns (pid, cmd, ppid, state) of all processes on the device."""
341    stdout, stderr = self.RunCmdOnDevice(
342        [
343            '/bin/ps', '--no-headers', '-A', '-o', 'pid,ppid,args:4096,state'
344        ],
345        quiet=True)
346    assert stderr == '', stderr
347    procs = []
348    for l in stdout.split('\n'):
349      if l == '':
350        continue
351      m = re.match(r'^\s*(\d+)\s+(\d+)\s+(.+)\s+(.+)', l, re.DOTALL)
352      assert m
353      procs.append((int(m.group(1)), m.group(3).rstrip(), int(m.group(2)),
354                    m.group(4)))
355    logging.debug("ListProcesses(<predicate>)->[%i processes]" % len(procs))
356    return procs
357
358  def _GetSessionManagerPid(self, procs):
359    """Returns the pid of the session_manager process, given the list of
360    processes."""
361    for pid, process, _, _ in procs:
362      argv = process.split()
363      if argv and os.path.basename(argv[0]) == 'session_manager':
364        return pid
365    return None
366
367  def GetChromeProcess(self):
368    """Locates the the main chrome browser process.
369
370    Chrome on cros is usually in /opt/google/chrome, but could be in
371    /usr/local/ for developer workflows - debug chrome is too large to fit on
372    rootfs.
373
374    Chrome spawns multiple processes for renderers. pids wrap around after they
375    are exhausted so looking for the smallest pid is not always correct. We
376    locate the session_manager's pid, and look for the chrome process that's an
377    immediate child. This is the main browser process.
378    """
379    procs = self.ListProcesses()
380    session_manager_pid = self._GetSessionManagerPid(procs)
381    if not session_manager_pid:
382      return None
383
384    # Find the chrome process that is the child of the session_manager.
385    for pid, process, ppid, _ in procs:
386      if ppid != session_manager_pid:
387        continue
388      for regex in _CHROME_PROCESS_REGEX:
389        path_match = re.match(regex, process)
390        if path_match is not None:
391          return {'pid': pid, 'path': path_match.group(), 'args': process}
392    return None
393
394  def GetChromePid(self):
395    """Returns pid of main chrome browser process."""
396    result = self.GetChromeProcess()
397    if result and 'pid' in result:
398      return result['pid']
399    return None
400
401  def RmRF(self, filename):
402    logging.debug("rm -rf %s" % filename)
403    self.RunCmdOnDevice(['rm', '-rf', filename], quiet=True)
404
405  def Chown(self, filename):
406    self.RunCmdOnDevice(['chown', '-R', 'chronos:chronos', filename])
407
408  def KillAllMatching(self, predicate):
409    kills = ['kill', '-KILL']
410    for pid, cmd, _, _ in self.ListProcesses():
411      if predicate(cmd):
412        logging.info('Killing %s, pid %d' % cmd, pid)
413        kills.append(pid)
414    logging.debug("KillAllMatching(<predicate>)->%i" % (len(kills) - 2))
415    if len(kills) > 2:
416      self.RunCmdOnDevice(kills, quiet=True)
417    return len(kills) - 2
418
419  def IsServiceRunning(self, service_name):
420    stdout, stderr = self.RunCmdOnDevice(['status', service_name], quiet=True)
421    assert stderr == '', stderr
422    running = 'running, process' in stdout
423    logging.debug("IsServiceRunning(%s)->%s" % (service_name, running))
424    return running
425
426  def GetRemotePort(self):
427    netstat = self.RunCmdOnDevice(['netstat', '-ant'])
428    netstat = netstat[0].split('\n')
429    ports_in_use = []
430
431    for line in netstat[2:]:
432      if not line:
433        continue
434      address_in_use = line.split()[3]
435      port_in_use = address_in_use.split(':')[-1]
436      ports_in_use.append(int(port_in_use))
437
438    ports_in_use.extend(self._reserved_ports)
439
440    new_port = sorted(ports_in_use)[-1] + 1
441    self._reserved_ports.append(new_port)
442
443    return new_port
444
445  def IsHTTPServerRunningOnPort(self, port):
446    wget_output = self.RunCmdOnDevice(['wget', 'localhost:%i' % (port), '-T1',
447                                       '-t1'])
448
449    if 'Connection refused' in wget_output[1]:
450      return False
451
452    return True
453
454  def FilesystemMountedAt(self, path):
455    """Returns the filesystem mounted at |path|"""
456    df_out, _ = self.RunCmdOnDevice(['/bin/df', path])
457    df_ary = df_out.split('\n')
458    # 3 lines for title, mount info, and empty line.
459    if len(df_ary) == 3:
460      line_ary = df_ary[1].split()
461      if line_ary:
462        return line_ary[0]
463    return None
464
465  def CryptohomePath(self, user):
466    """Returns the cryptohome mount point for |user|."""
467    stdout, stderr = self.RunCmdOnDevice(['cryptohome-path', 'user', "'%s'" %
468                                          user])
469    if stderr != '':
470      raise OSError('cryptohome-path failed: %s' % stderr)
471    return stdout.rstrip()
472
473  def IsCryptohomeMounted(self, username, is_guest):
474    """Returns True iff |user|'s cryptohome is mounted."""
475    profile_path = self.CryptohomePath(username)
476    mount = self.FilesystemMountedAt(profile_path)
477    mount_prefix = 'guestfs' if is_guest else '/home/.shadow/'
478    return mount and mount.startswith(mount_prefix)
479
480  def TakeScreenShot(self, screenshot_prefix):
481    """Takes a screenshot, useful for debugging failures."""
482    # TODO(achuith): Find a better location for screenshots. Cros autotests
483    # upload everything in /var/log so use /var/log/screenshots for now.
484    SCREENSHOT_DIR = '/var/log/screenshots/'
485    SCREENSHOT_EXT = '.png'
486
487    self.RunCmdOnDevice(['mkdir', '-p', SCREENSHOT_DIR])
488    # Large number of screenshots can increase hardware lab bandwidth
489    # dramatically, so keep this number low. crbug.com/524814.
490    for i in xrange(2):
491      screenshot_file = ('%s%s-%d%s' %
492                         (SCREENSHOT_DIR, screenshot_prefix, i, SCREENSHOT_EXT))
493      if not self.FileExistsOnDevice(screenshot_file):
494        self.RunCmdOnDevice([
495            '/usr/local/autotest/bin/screenshot.py', screenshot_file
496        ])
497        return
498    logging.warning('screenshot directory full.')
499
500  def RestartUI(self, clear_enterprise_policy):
501    logging.info('(Re)starting the ui (logs the user out)')
502    if clear_enterprise_policy:
503      self.RunCmdOnDevice(['stop', 'ui'])
504      self.RmRF('/var/lib/whitelist/*')
505      self.RmRF(r'/home/chronos/Local\ State')
506
507    if self.IsServiceRunning('ui'):
508      self.RunCmdOnDevice(['restart', 'ui'])
509    else:
510      self.RunCmdOnDevice(['start', 'ui'])
511
512  def CloseConnection(self):
513    if not self.local:
514      with open(os.devnull, 'w') as devnull:
515        subprocess.call(
516            self.FormSSHCommandLine(['-O', 'exit', self._hostname]),
517            stdout=devnull,
518            stderr=devnull)
519