1# pylint: disable-msg=C0111
2import os, time, signal, socket, re, fnmatch, logging, threading
3import paramiko
4
5from autotest_lib.client.common_lib import utils, error, global_config
6from autotest_lib.server import subcommand
7from autotest_lib.server.hosts import abstract_ssh
8
9
10class ParamikoHost(abstract_ssh.AbstractSSHHost):
11    KEEPALIVE_TIMEOUT_SECONDS = 30
12    CONNECT_TIMEOUT_SECONDS = 30
13    CONNECT_TIMEOUT_RETRIES = 3
14    BUFFSIZE = 2**16
15
16    def _initialize(self, hostname, *args, **dargs):
17        super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs)
18
19        # paramiko is very noisy, tone down the logging
20        paramiko.util.log_to_file("/dev/null", paramiko.util.ERROR)
21
22        self.keys = self.get_user_keys(hostname)
23        self.pid = None
24
25
26    @staticmethod
27    def _load_key(path):
28        """Given a path to a private key file, load the appropriate keyfile.
29
30        Tries to load the file as both an RSAKey and a DSAKey. If the file
31        cannot be loaded as either type, returns None."""
32        try:
33            return paramiko.DSSKey.from_private_key_file(path)
34        except paramiko.SSHException:
35            try:
36                return paramiko.RSAKey.from_private_key_file(path)
37            except paramiko.SSHException:
38                return None
39
40
41    @staticmethod
42    def _parse_config_line(line):
43        """Given an ssh config line, return a (key, value) tuple for the
44        config value listed in the line, or (None, None)"""
45        match = re.match(r"\s*(\w+)\s*=?(.*)\n", line)
46        if match:
47            return match.groups()
48        else:
49            return None, None
50
51
52    @staticmethod
53    def get_user_keys(hostname):
54        """Returns a mapping of path -> paramiko.PKey entries available for
55        this user. Keys are found in the default locations (~/.ssh/id_[d|r]sa)
56        as well as any IdentityFile entries in the standard ssh config files.
57        """
58        raw_identity_files = ["~/.ssh/id_dsa", "~/.ssh/id_rsa"]
59        for config_path in ("/etc/ssh/ssh_config", "~/.ssh/config"):
60            config_path = os.path.expanduser(config_path)
61            if not os.path.exists(config_path):
62                continue
63            host_pattern = "*"
64            config_lines = open(config_path).readlines()
65            for line in config_lines:
66                key, value = ParamikoHost._parse_config_line(line)
67                if key == "Host":
68                    host_pattern = value
69                elif (key == "IdentityFile"
70                      and fnmatch.fnmatch(hostname, host_pattern)):
71                    raw_identity_files.append(value)
72
73        # drop any files that use percent-escapes; we don't support them
74        identity_files = []
75        UNSUPPORTED_ESCAPES = ["%d", "%u", "%l", "%h", "%r"]
76        for path in raw_identity_files:
77            # skip this path if it uses % escapes
78            if sum((escape in path) for escape in UNSUPPORTED_ESCAPES):
79                continue
80            path = os.path.expanduser(path)
81            if os.path.exists(path):
82                identity_files.append(path)
83
84        # load up all the keys that we can and return them
85        user_keys = {}
86        for path in identity_files:
87            key = ParamikoHost._load_key(path)
88            if key:
89                user_keys[path] = key
90
91        # load up all the ssh agent keys
92        use_sshagent = global_config.global_config.get_config_value(
93            'AUTOSERV', 'use_sshagent_with_paramiko', type=bool)
94        if use_sshagent:
95            ssh_agent = paramiko.Agent()
96            for i, key in enumerate(ssh_agent.get_keys()):
97                user_keys['agent-key-%d' % i] = key
98
99        return user_keys
100
101
102    def _check_transport_error(self, transport):
103        error = transport.get_exception()
104        if error:
105            transport.close()
106            raise error
107
108
109    def _connect_socket(self):
110        """Return a socket for use in instantiating a paramiko transport. Does
111        not have to be a literal socket, it can be anything that the
112        paramiko.Transport constructor accepts."""
113        return self.hostname, self.port
114
115
116    def _connect_transport(self, pkey):
117        for _ in xrange(self.CONNECT_TIMEOUT_RETRIES):
118            transport = paramiko.Transport(self._connect_socket())
119            completed = threading.Event()
120            transport.start_client(completed)
121            completed.wait(self.CONNECT_TIMEOUT_SECONDS)
122            if completed.isSet():
123                self._check_transport_error(transport)
124                completed.clear()
125                transport.auth_publickey(self.user, pkey, completed)
126                completed.wait(self.CONNECT_TIMEOUT_SECONDS)
127                if completed.isSet():
128                    self._check_transport_error(transport)
129                    if not transport.is_authenticated():
130                        transport.close()
131                        raise paramiko.AuthenticationException()
132                    return transport
133            logging.warning("SSH negotiation (%s:%d) timed out, retrying",
134                         self.hostname, self.port)
135            # HACK: we can't count on transport.join not hanging now, either
136            transport.join = lambda: None
137            transport.close()
138        logging.error("SSH negotation (%s:%d) has timed out %s times, "
139                      "giving up", self.hostname, self.port,
140                      self.CONNECT_TIMEOUT_RETRIES)
141        raise error.AutoservSSHTimeout("SSH negotiation timed out")
142
143
144    def _init_transport(self):
145        for path, key in self.keys.iteritems():
146            try:
147                logging.debug("Connecting with %s", path)
148                transport = self._connect_transport(key)
149                transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS)
150                self.transport = transport
151                self.pid = os.getpid()
152                return
153            except paramiko.AuthenticationException:
154                logging.debug("Authentication failure")
155        else:
156            raise error.AutoservSshPermissionDeniedError(
157                "Permission denied using all keys available to ParamikoHost",
158                utils.CmdResult())
159
160
161    def _open_channel(self, timeout):
162        start_time = time.time()
163        if os.getpid() != self.pid:
164            if self.pid is not None:
165                # HACK: paramiko tries to join() on its worker thread
166                # and this just hangs on linux after a fork()
167                self.transport.join = lambda: None
168                self.transport.atfork()
169                join_hook = lambda cmd: self._close_transport()
170                subcommand.subcommand.register_join_hook(join_hook)
171                logging.debug("Reopening SSH connection after a process fork")
172            self._init_transport()
173
174        channel = None
175        try:
176            channel = self.transport.open_session()
177        except (socket.error, paramiko.SSHException, EOFError), e:
178            logging.warning("Exception occured while opening session: %s", e)
179            if time.time() - start_time >= timeout:
180                raise error.AutoservSSHTimeout("ssh failed: %s" % e)
181
182        if not channel:
183            # we couldn't get a channel; re-initing transport should fix that
184            try:
185                self.transport.close()
186            except Exception, e:
187                logging.debug("paramiko.Transport.close failed with %s", e)
188            self._init_transport()
189            return self.transport.open_session()
190        else:
191            return channel
192
193
194    def _close_transport(self):
195        if os.getpid() == self.pid:
196            self.transport.close()
197
198
199    def close(self):
200        super(ParamikoHost, self).close()
201        self._close_transport()
202
203
204    @classmethod
205    def _exhaust_stream(cls, tee, output_list, recvfunc):
206        while True:
207            try:
208                output_list.append(recvfunc(cls.BUFFSIZE))
209            except socket.timeout:
210                return
211            tee.write(output_list[-1])
212            if not output_list[-1]:
213                return
214
215
216    @classmethod
217    def __send_stdin(cls, channel, stdin):
218        if not stdin or not channel.send_ready():
219            # nothing more to send or just no space to send now
220            return
221
222        sent = channel.send(stdin[:cls.BUFFSIZE])
223        if not sent:
224            logging.warning('Could not send a single stdin byte.')
225        else:
226            stdin = stdin[sent:]
227            if not stdin:
228                # no more stdin input, close output direction
229                channel.shutdown_write()
230        return stdin
231
232
233    def run(self, command, timeout=3600, ignore_status=False,
234            stdout_tee=utils.TEE_TO_LOGS, stderr_tee=utils.TEE_TO_LOGS,
235            connect_timeout=30, stdin=None, verbose=True, args=(),
236            ignore_timeout=False):
237        """
238        Run a command on the remote host.
239        @see common_lib.hosts.host.run()
240
241        @param connect_timeout: connection timeout (in seconds)
242        @param options: string with additional ssh command options
243        @param verbose: log the commands
244        @param ignore_timeout: bool True command timeouts should be
245                               ignored.  Will return None on command timeout.
246
247        @raises AutoservRunError: if the command failed
248        @raises AutoservSSHTimeout: ssh connection has timed out
249        """
250
251        stdout = utils.get_stream_tee_file(
252                stdout_tee, utils.DEFAULT_STDOUT_LEVEL,
253                prefix=utils.STDOUT_PREFIX)
254        stderr = utils.get_stream_tee_file(
255                stderr_tee, utils.get_stderr_level(ignore_status),
256                prefix=utils.STDERR_PREFIX)
257
258        for arg in args:
259            command += ' "%s"' % utils.sh_escape(arg)
260
261        if verbose:
262            logging.debug("Running (ssh-paramiko) '%s'", command)
263
264        # start up the command
265        start_time = time.time()
266        try:
267            channel = self._open_channel(timeout)
268            channel.exec_command(command)
269        except (socket.error, paramiko.SSHException, EOFError), e:
270            # This has to match the string from paramiko *exactly*.
271            if str(e) != 'Channel closed.':
272                raise error.AutoservSSHTimeout("ssh failed: %s" % e)
273
274        # pull in all the stdout, stderr until the command terminates
275        raw_stdout, raw_stderr = [], []
276        timed_out = False
277        while not channel.exit_status_ready():
278            if channel.recv_ready():
279                raw_stdout.append(channel.recv(self.BUFFSIZE))
280                stdout.write(raw_stdout[-1])
281            if channel.recv_stderr_ready():
282                raw_stderr.append(channel.recv_stderr(self.BUFFSIZE))
283                stderr.write(raw_stderr[-1])
284            if timeout and time.time() - start_time > timeout:
285                timed_out = True
286                break
287            stdin = self.__send_stdin(channel, stdin)
288            time.sleep(1)
289
290        if timed_out:
291            exit_status = -signal.SIGTERM
292        else:
293            exit_status = channel.recv_exit_status()
294        channel.settimeout(10)
295        self._exhaust_stream(stdout, raw_stdout, channel.recv)
296        self._exhaust_stream(stderr, raw_stderr, channel.recv_stderr)
297        channel.close()
298        duration = time.time() - start_time
299
300        # create the appropriate results
301        stdout = "".join(raw_stdout)
302        stderr = "".join(raw_stderr)
303        result = utils.CmdResult(command, stdout, stderr, exit_status,
304                                 duration)
305        if exit_status == -signal.SIGHUP:
306            msg = "ssh connection unexpectedly terminated"
307            raise error.AutoservRunError(msg, result)
308        if timed_out:
309            logging.warning('Paramiko command timed out after %s sec: %s', timeout,
310                         command)
311            if not ignore_timeout:
312                raise error.AutoservRunError("command timed out", result)
313        if not ignore_status and exit_status:
314            raise error.AutoservRunError(command, result)
315        return result
316