1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""A wrapper around ssh for common operations on a CrOS-based device"""
5import logging
6import os
7import re
8import shutil
9import stat
10import subprocess
11import tempfile
12
13# Some developers' workflow includes running the Chrome process from
14# /usr/local/... instead of the default location. We have to check for both
15# paths in order to support this workflow.
16_CHROME_PROCESS_REGEX = [re.compile(r'^/opt/google/chrome/chrome '),
17                         re.compile(r'^/usr/local/?.*/chrome/chrome ')]
18
19
def RunCmd(args, cwd=None, quiet=False):
  """Opens a subprocess to execute a program and returns its return value.

  The program runs without a shell, and its stdin, stdout and stderr are all
  attached to os.devnull.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, the command line is not logged.

  Returns:
    Return code from the command execution.

  Raises:
    OSError: If the program cannot be started (for example, it is not
      installed). HasSSH() relies on this.
  """
  if not quiet:
    # A bare string names the program itself; joining it character by
    # character would log "s s h" instead of "ssh".
    command = args if isinstance(args, str) else ' '.join(args)
    logging.debug(command + ' ' + (cwd or ''))
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args,
                         cwd=cwd,
                         stdout=devnull,
                         stderr=devnull,
                         stdin=devnull,
                         shell=False)
    return p.wait()
42
43
def GetAllCmdOutput(args, cwd=None, quiet=False):
  """Opens a subprocess to execute a program and returns its output.

  The program's stdin is attached to os.devnull. Its stdout and stderr are
  captured separately.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, neither the command line nor its output is logged.

  Returns:
    A (stdout, stderr) tuple holding everything the command wrote to each
    stream. Nothing is printed; when not |quiet| both are only sent to
    logging.debug.
  """
  if not quiet:
    # A bare string names the program itself; do not space out its characters.
    command = args if isinstance(args, str) else ' '.join(args)
    logging.debug(command + ' ' + (cwd or ''))
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args,
                         cwd=cwd,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         stdin=devnull)
    stdout, stderr = p.communicate()
    if not quiet:
      logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
    return stdout, stderr
69
70
def HasSSH():
  """Reports whether both the `ssh` and `scp` executables can be launched.

  Only the ability to start each program matters; their exit codes are
  ignored. A missing executable surfaces as OSError from RunCmd().

  Returns:
    True if both programs started, False otherwise.
  """
  try:
    for tool in ('ssh', 'scp'):
      RunCmd([tool], quiet=True)
  except OSError:
    logging.debug("HasSSH()->False")
    return False
  logging.debug("HasSSH()->True")
  return True
80
81
class LoginException(Exception):
  """Raised when logging into the device as root over ssh does not succeed."""


class KeylessLoginRequiredException(LoginException):
  """Raised when the device rejects ssh public-key authentication."""


class DNSFailureException(LoginException):
  """Raised when the device's hostname cannot be resolved."""
92
93
class CrOSInterface(object):
  """Runs commands and moves files on a CrOS device, locally or over ssh.

  When |hostname| is falsy the interface is "local": commands are run through
  `sh -c` on this machine and file operations use the local filesystem.
  Otherwise commands run over ssh as root@|hostname|, reusing a master
  connection that the constructor opens with ControlPersist.
  """

  def __init__(self, hostname=None, ssh_port=None, ssh_identity=None):
    """Sets up the interface and, for a remote device, the master connection.

    Args:
      hostname: Device hostname or address. None/empty selects local mode.
      ssh_port: Port for ssh/scp. When None, no port option is passed and the
        tools fall back to their configured default.
      ssh_identity: Optional private key path. The file is chmod'ed to
        owner-read-only because ssh refuses keys with open permissions.
    """
    self._hostname = hostname
    self._ssh_port = ssh_port

    # List of ports generated from GetRemotePort() that may not be in use yet.
    self._reserved_ports = []

    if self.local:
      return

    self._ssh_identity = None
    self._ssh_args = ['-o ConnectTimeout=5', '-o StrictHostKeyChecking=no',
                      '-o KbdInteractiveAuthentication=no',
                      '-o PreferredAuthentications=publickey',
                      '-o UserKnownHostsFile=/dev/null', '-o ControlMaster=no']

    if ssh_identity:
      self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))
      os.chmod(self._ssh_identity, stat.S_IREAD)

    # Establish master SSH connection using ControlPersist.
    # Since only one test will be run on a remote host at a time,
    # the control socket filename can be telemetry@hostname.
    self._ssh_control_file = '/tmp/telemetry@' + hostname
    with open(os.devnull, 'w') as devnull:
      subprocess.call(
          self.FormSSHCommandLine(['-M', '-o ControlPersist=yes']),
          stdin=devnull,
          stdout=devnull,
          stderr=devnull)

  def __enter__(self):
    return self

  def __exit__(self, *args):
    self.CloseConnection()

  @property
  def local(self):
    """True when no hostname was given and commands run on this machine."""
    return not self._hostname

  @property
  def hostname(self):
    return self._hostname

  @property
  def ssh_port(self):
    return self._ssh_port

  def FormSSHCommandLine(self, args, extra_ssh_args=None):
    """Constructs a subprocess-suitable command line for `ssh'.

    Args:
      args: Command words to run on the device. In local mode they are joined
        with spaces and handed to `sh -c`, so every item must be a string.
      extra_ssh_args: Optional extra ssh options (ignored in local mode).

    Returns:
      An argv list.
    """
    if self.local:
      # We run the command through the shell locally for consistency with
      # how commands are run through SSH (crbug.com/239161). This work
      # around will be unnecessary once we implement a persistent SSH
      # connection to run remote commands (crbug.com/239607).
      return ['sh', '-c', " ".join(args)]

    full_args = ['ssh', '-o ForwardX11=no', '-o ForwardX11Trusted=no', '-n',
                 '-S', self._ssh_control_file] + self._ssh_args
    if self._ssh_identity is not None:
      full_args.extend(['-i', self._ssh_identity])
    if extra_ssh_args:
      full_args.extend(extra_ssh_args)
    full_args.append('root@%s' % self._hostname)
    # ssh_port defaults to None; formatting it with %d would raise TypeError,
    # so the option is left out and ssh uses its default port.
    if self._ssh_port is not None:
      full_args.append('-p%d' % self._ssh_port)
    full_args.extend(args)
    return full_args

  def _FormSCPCommandLine(self, src, dst, extra_scp_args=None):
    """Constructs a subprocess-suitable command line for `scp'.

    Note: this function is not designed to work with IPv6 addresses, which need
    to have their addresses enclosed in brackets and a '-6' flag supplied
    in order to be properly parsed by `scp'.
    """
    assert not self.local, "Cannot use SCP on local target."

    args = ['scp']
    # Only pass -P when a port is known; otherwise scp would get "-P None".
    if self._ssh_port is not None:
      args.extend(['-P', str(self._ssh_port)])
    args.extend(self._ssh_args)
    if self._ssh_identity:
      args.extend(['-i', self._ssh_identity])
    if extra_scp_args:
      args.extend(extra_scp_args)
    args += [src, dst]
    return args

  def _FormSCPToRemote(self,
                       source,
                       remote_dest,
                       extra_scp_args=None,
                       user='root'):
    """scp command line copying local |source| to |remote_dest| on the device."""
    return self._FormSCPCommandLine(source,
                                    '%s@%s:%s' % (user, self._hostname,
                                                  remote_dest),
                                    extra_scp_args=extra_scp_args)

  def _FormSCPFromRemote(self,
                         remote_source,
                         dest,
                         extra_scp_args=None,
                         user='root'):
    """scp command line copying |remote_source| on the device to local |dest|."""
    return self._FormSCPCommandLine('%s@%s:%s' % (user, self._hostname,
                                                  remote_source),
                                    dest,
                                    extra_scp_args=extra_scp_args)

  def _RemoveSSHWarnings(self, toClean):
    """Removes specific ssh warning lines from a string.

    Args:
      toClean: A string that may be containing multiple lines.

    Returns:
      A copy of toClean with all the Warning lines removed.
    """
    # Remove the Warning about connecting to a new host for the first time.
    return re.sub(
        r'Warning: Permanently added [^\n]* to the list of known hosts.\s\n',
        '', toClean)

  def RunCmdOnDevice(self, args, cwd=None, quiet=False):
    """Runs |args| on the device (or via `sh -c` locally).

    Returns:
      A (stdout, stderr) tuple, with ssh's known-hosts warning removed from
      stderr.
    """
    stdout, stderr = GetAllCmdOutput(
        self.FormSSHCommandLine(args),
        cwd,
        quiet=quiet)
    # The initial login will add the host to the hosts file but will also print
    # a warning to stderr that we need to remove.
    stderr = self._RemoveSSHWarnings(stderr)
    return stdout, stderr

  def TryLogin(self):
    """Checks that ssh login to the device as root works.

    Raises:
      KeylessLoginRequiredException: Public-key authentication was refused.
      DNSFailureException: The hostname could not be resolved.
      LoginException: Any other failure, including a $USER other than root.
    """
    logging.debug('TryLogin()')
    assert not self.local
    stdout, stderr = self.RunCmdOnDevice(['echo', '$USER'], quiet=True)
    if stderr != '':
      if 'Host key verification failed' in stderr:
        raise LoginException(('%s host key verification failed. ' +
                              'SSH to it manually to fix connectivity.') %
                             self._hostname)
      if 'Operation timed out' in stderr:
        raise LoginException('Timed out while logging into %s' % self._hostname)
      if 'UNPROTECTED PRIVATE KEY FILE!' in stderr:
        raise LoginException('Permissions for %s are too open. To fix this,\n'
                             'chmod 600 %s' % (self._ssh_identity,
                                               self._ssh_identity))
      if 'Permission denied (publickey,keyboard-interactive)' in stderr:
        raise KeylessLoginRequiredException('Need to set up ssh auth for %s' %
                                            self._hostname)
      if 'Could not resolve hostname' in stderr:
        raise DNSFailureException('Unable to resolve the hostname for: %s' %
                                  self._hostname)
      raise LoginException('While logging into %s, got %s' % (self._hostname,
                                                              stderr))
    if stdout != 'root\n':
      raise LoginException('Logged into %s, expected $USER=root, but got %s.' %
                           (self._hostname, stdout))

  def FileExistsOnDevice(self, file_name):
    """Returns True if |file_name| exists on the device.

    Raises:
      OSError: If the remote command wrote anything to stderr.
    """
    if self.local:
      return os.path.exists(file_name)

    stdout, stderr = self.RunCmdOnDevice(
        [
            'if', 'test', '-e', file_name, ';', 'then', 'echo', '1', ';', 'fi'
        ],
        quiet=True)
    if stderr != '':
      if "Connection timed out" in stderr:
        raise OSError('Machine wasn\'t responding to ssh: %s' % stderr)
      raise OSError('Unexpected error: %s' % stderr)
    exists = stdout == '1\n'
    logging.debug("FileExistsOnDevice(<text>, %s)->%s", file_name, exists)
    return exists

  def PushFile(self, filename, remote_filename):
    """Copies local |filename| (recursively) to |remote_filename| on the device.

    Raises:
      OSError: If the copy wrote anything to stderr.
    """
    if self.local:
      args = ['cp', '-r', filename, remote_filename]
      stdout, stderr = GetAllCmdOutput(args, quiet=True)
      if stderr != '':
        raise OSError('No such file or directory %s' % stderr)
      return

    args = self._FormSCPToRemote(
        os.path.abspath(filename),
        remote_filename,
        extra_scp_args=['-r'])

    stdout, stderr = GetAllCmdOutput(args, quiet=True)
    stderr = self._RemoveSSHWarnings(stderr)
    if stderr != '':
      raise OSError('No such file or directory %s' % stderr)

  def PushContents(self, text, remote_filename):
    """Writes |text| to a temporary file and pushes it to |remote_filename|."""
    logging.debug("PushContents(<text>, %s)", remote_filename)
    with tempfile.NamedTemporaryFile() as f:
      f.write(text)
      f.flush()
      self.PushFile(f.name, remote_filename)

  def GetFile(self, filename, destfile=None):
    """Copies |filename| on the device to the local file |destfile|.

    Args:
      filename: The name of the source file on the device.
      destfile: The local file to copy to. If it is not specified, it is the
        basename of the source file (remote mode only).

    Raises:
      OSError: In local mode, when |destfile| is None or equal to |filename|;
        in remote mode, when scp writes anything to stderr.
    """
    logging.debug("GetFile(%s, %s)", filename, destfile)
    if self.local:
      if destfile is not None and destfile != filename:
        shutil.copyfile(filename, destfile)
        return
      else:
        raise OSError('No such file or directory %s' % filename)

    if destfile is None:
      destfile = os.path.basename(filename)
    args = self._FormSCPFromRemote(filename, os.path.abspath(destfile))

    stdout, stderr = GetAllCmdOutput(args, quiet=True)
    stderr = self._RemoveSSHWarnings(stderr)
    if stderr != '':
      raise OSError('No such file or directory %s' % stderr)

  def GetFileContents(self, filename):
    """Get the contents of a file on the device.

    Args:
      filename: The name of the file on the device.

    Returns:
      A string containing the contents of the file.
    """
    with tempfile.NamedTemporaryFile() as t:
      self.GetFile(filename, t.name)
      with open(t.name, 'r') as f2:
        res = f2.read()
        logging.debug("GetFileContents(%s)->%s", filename, res)
        return res

  def HasSystemd(self):
    """Return True or False to indicate if systemd is used.

    Note: This function checks to see if the 'systemctl' utility
    is installed. This is only installed along with the systemd daemon.
    """
    _, stderr = self.RunCmdOnDevice(['systemctl'], quiet=True)
    return stderr == ''

  def ListProcesses(self):
    """Returns (pid, cmd, ppid, state) of all processes on the device."""
    stdout, stderr = self.RunCmdOnDevice(
        [
            '/bin/ps', '--no-headers', '-A', '-o', 'pid,ppid,args:4096,state'
        ],
        quiet=True)
    assert stderr == '', stderr
    procs = []
    for l in stdout.split('\n'):
      if l == '':
        continue
      # pid and ppid lead; the greedy args group leaves only the last
      # whitespace-separated token for the state column.
      m = re.match(r'^\s*(\d+)\s+(\d+)\s+(.+)\s+(.+)', l, re.DOTALL)
      assert m
      procs.append((int(m.group(1)), m.group(3).rstrip(), int(m.group(2)),
                    m.group(4)))
    logging.debug("ListProcesses(<predicate>)->[%i processes]", len(procs))
    return procs

  def _GetSessionManagerPid(self, procs):
    """Returns the pid of the session_manager process, given the list of
    processes."""
    for pid, process, _, _ in procs:
      argv = process.split()
      if argv and os.path.basename(argv[0]) == 'session_manager':
        return pid
    return None

  def GetChromeProcess(self):
    """Locates the main chrome browser process.

    Chrome on cros is usually in /opt/google/chrome, but could be in
    /usr/local/ for developer workflows - debug chrome is too large to fit on
    rootfs.

    Chrome spawns multiple processes for renderers. pids wrap around after they
    are exhausted so looking for the smallest pid is not always correct. We
    locate the session_manager's pid, and look for the chrome process that's an
    immediate child. This is the main browser process.

    Returns:
      A dict with 'pid', 'path' (the matched executable prefix, which keeps the
      trailing space the pattern requires) and 'args' (full command line), or
      None if no such process is found.
    """
    procs = self.ListProcesses()
    session_manager_pid = self._GetSessionManagerPid(procs)
    if not session_manager_pid:
      return None

    # Find the chrome process that is the child of the session_manager.
    for pid, process, ppid, _ in procs:
      if ppid != session_manager_pid:
        continue
      for regex in _CHROME_PROCESS_REGEX:
        path_match = re.match(regex, process)
        if path_match is not None:
          return {'pid': pid, 'path': path_match.group(), 'args': process}
    return None

  def GetChromePid(self):
    """Returns pid of main chrome browser process."""
    result = self.GetChromeProcess()
    if result and 'pid' in result:
      return result['pid']
    return None

  def RmRF(self, filename):
    """Runs `rm -rf |filename|` on the device.

    |filename| reaches the shell unquoted, so globs expand (RestartUI relies on
    this) and special characters must be escaped by the caller.
    """
    logging.debug("rm -rf %s", filename)
    self.RunCmdOnDevice(['rm', '-rf', filename], quiet=True)

  def Chown(self, filename):
    """Recursively gives |filename| to chronos:chronos on the device."""
    self.RunCmdOnDevice(['chown', '-R', 'chronos:chronos', filename])

  def KillAllMatching(self, predicate):
    """Sends SIGKILL to every device process whose command line matches.

    Args:
      predicate: Callable taking a process command line and returning truthy
        for processes to kill.

    Returns:
      The number of processes a kill was issued for.
    """
    pids = []
    for pid, cmd, _, _ in self.ListProcesses():
      if predicate(cmd):
        logging.info('Killing %s, pid %d', cmd, pid)
        # Command words are joined by `sh -c` locally and passed as argv over
        # ssh, so they must be strings.
        pids.append(str(pid))
    logging.debug("KillAllMatching(<predicate>)->%i", len(pids))
    if pids:
      self.RunCmdOnDevice(['kill', '-KILL'] + pids, quiet=True)
    return len(pids)

  def IsServiceRunning(self, service_name):
    """Check with the init daemon if the given service is running."""
    if self.HasSystemd():
      # Querying for the pid of the service will return 'MainPID=0' if
      # the service is not running.
      stdout, stderr = self.RunCmdOnDevice(
          ['systemctl', 'show', '-p', 'MainPID', service_name], quiet=True)
      running = int(stdout.split('=')[1]) != 0
    else:
      stdout, stderr = self.RunCmdOnDevice(['status', service_name], quiet=True)
      running = 'running, process' in stdout
    assert stderr == '', stderr
    logging.debug("IsServiceRunning(%s)->%s", service_name, running)
    return running

  def GetRemotePort(self):
    """Returns a device port above every port in use or already handed out.

    The port is only reserved locally (in self._reserved_ports), not bound on
    the device.
    """
    stdout, _ = self.RunCmdOnDevice(['netstat', '-ant'])
    ports_in_use = []

    # Skip the two header lines; the fourth column is the local address.
    for line in stdout.split('\n')[2:]:
      if not line:
        continue
      address_in_use = line.split()[3]
      ports_in_use.append(int(address_in_use.split(':')[-1]))

    ports_in_use.extend(self._reserved_ports)

    new_port = sorted(ports_in_use)[-1] + 1
    self._reserved_ports.append(new_port)

    return new_port

  def IsHTTPServerRunningOnPort(self, port):
    """Returns False only if wget on the device reports 'Connection refused'."""
    wget_output = self.RunCmdOnDevice(['wget', 'localhost:%i' % (port), '-T1',
                                       '-t1'])

    if 'Connection refused' in wget_output[1]:
      return False

    return True

  def _GetMountSourceAndTarget(self, path):
    """Returns [source, target] of the mount holding |path|, or None."""
    df_out, _ = self.RunCmdOnDevice(['/bin/df', '--output=source,target', path])
    df_ary = df_out.split('\n')
    # 3 lines for title, mount info, and empty line.
    if len(df_ary) == 3:
      line_ary = df_ary[1].split()
      return line_ary if len(line_ary) == 2 else None
    return None

  def FilesystemMountedAt(self, path):
    """Returns the filesystem mounted at |path|"""
    mount_info = self._GetMountSourceAndTarget(path)
    return mount_info[0] if mount_info else None

  def CryptohomePath(self, user):
    """Returns the cryptohome mount point for |user|.

    Raises:
      OSError: If cryptohome-path wrote anything to stderr.
    """
    # NOTE(review): |user| is wrapped in single quotes for the shell; a user
    # name containing a single quote would break that quoting.
    stdout, stderr = self.RunCmdOnDevice(['cryptohome-path', 'user', "'%s'" %
                                          user])
    if stderr != '':
      raise OSError('cryptohome-path failed: %s' % stderr)
    return stdout.rstrip()

  def IsCryptohomeMounted(self, username, is_guest):
    """Returns True iff |username|'s cryptohome is mounted.

    Args:
      username: The user whose cryptohome is checked.
      is_guest: Whether a guest (guestfs) mount is expected; a mount of the
        other kind counts as not mounted.
    """
    profile_path = self.CryptohomePath(username)
    mount_info = self._GetMountSourceAndTarget(profile_path)
    if mount_info:
      # Checks if the filesytem at |profile_path| is mounted on |profile_path|
      # itself. Before mounting cryptohome, it shows an upper directory (/home).
      is_guestfs = (mount_info[0] == 'guestfs')
      return is_guestfs == is_guest and mount_info[1] == profile_path
    return False

  def TakeScreenshot(self, file_path):
    """Runs the device screenshot script; True if it produced no output."""
    stdout, stderr = self.RunCmdOnDevice(
        ['/usr/local/autotest/bin/screenshot.py', file_path])
    return stdout == '' and stderr == ''

  def TakeScreenshotWithPrefix(self, screenshot_prefix):
    """Takes a screenshot, useful for debugging failures."""
    # TODO(achuith): Find a better location for screenshots. Cros autotests
    # upload everything in /var/log so use /var/log/screenshots for now.
    SCREENSHOT_DIR = '/var/log/screenshots/'
    SCREENSHOT_EXT = '.png'

    self.RunCmdOnDevice(['mkdir', '-p', SCREENSHOT_DIR])
    # Large number of screenshots can increase hardware lab bandwidth
    # dramatically, so keep this number low. crbug.com/524814.
    for i in range(2):
      screenshot_file = ('%s%s-%d%s' %
                         (SCREENSHOT_DIR, screenshot_prefix, i, SCREENSHOT_EXT))
      if not self.FileExistsOnDevice(screenshot_file):
        return self.TakeScreenshot(screenshot_file)
    logging.warning('screenshot directory full.')
    return False

  def GetArchName(self):
    """Returns the raw stdout of `uname -m` (trailing newline included)."""
    return self.RunCmdOnDevice(['uname', '-m'])[0]

  def IsRunningOnVM(self):
    """True unless `crossystem inside_vm` prints exactly '0'."""
    return self.RunCmdOnDevice(['crossystem', 'inside_vm'])[0] != '0'

  def LsbReleaseValue(self, key, default):
    """/etc/lsb-release is a file with key=value pairs."""
    lines = self.GetFileContents('/etc/lsb-release').split('\n')
    for l in lines:
      m = re.match(r'([^=]*)=(.*)', l)
      if m and m.group(1) == key:
        return m.group(2)
    return default

  def GetDeviceTypeName(self):
    """DEVICETYPE in /etc/lsb-release is CHROMEBOOK, CHROMEBIT, etc."""
    return self.LsbReleaseValue(key='DEVICETYPE', default='CHROMEBOOK')

  def RestartUI(self, clear_enterprise_policy):
    """(Re)starts the ui job, optionally wiping enterprise policy state first."""
    logging.info('(Re)starting the ui (logs the user out)')
    start_cmd = ['start', 'ui']
    restart_cmd = ['restart', 'ui']
    stop_cmd = ['stop', 'ui']
    if self.HasSystemd():
      start_cmd.insert(0, 'systemctl')
      restart_cmd.insert(0, 'systemctl')
      stop_cmd.insert(0, 'systemctl')
    if clear_enterprise_policy:
      self.RunCmdOnDevice(stop_cmd)
      self.RmRF('/var/lib/whitelist/*')
      self.RmRF(r'/home/chronos/Local\ State')

    if self.IsServiceRunning('ui'):
      self.RunCmdOnDevice(restart_cmd)
    else:
      self.RunCmdOnDevice(start_cmd)

  def CloseConnection(self):
    """Asks the master ssh connection to exit (no-op in local mode)."""
    if not self.local:
      with open(os.devnull, 'w') as devnull:
        subprocess.call(
            self.FormSSHCommandLine(['-O', 'exit', self._hostname]),
            stdout=devnull,
            stderr=devnull)
571