1# Lint as: python2, python3
2# Copyright 2008 Google Inc, Martin J. Bligh <mbligh@google.com>,
3#                Benjamin Poirier, Ryan Stutsman
4# Released under the GPL v2
5"""
6Miscellaneous small functions.
7
8DO NOT import this file directly - it is mixed in by server/utils.py,
9import that instead
10"""
11
12from __future__ import absolute_import
13from __future__ import division
14from __future__ import print_function
15
16import atexit, os, re, shutil, textwrap, sys, tempfile, types
17import six
18
19from autotest_lib.client.common_lib import barrier, utils
20from autotest_lib.server import subcommand
21
22
# Registry of temporary directories, keyed by process id: maps
# os.getpid() -> list of directory paths created by get_tmp_dir() in that
# process. __clean_tmp_dirs() only looks at the entry of the calling pid.
__tmp_dirs = {}
25
26
def scp_remote_escape(filename):
    """
    Backslash-escape a filename so it can be given to scp as a remote path.

    A remote scp path has to be quoted twice ("bis-quoting"): each
    shell-special character is prefixed with a backslash here, and the
    result is then run through utils.sh_escape() for the outer
    double-quoted level.
    scp does not support a newline in the filename.

    Args:
            filename: the filename string to escape.

    Returns:
            The escaped filename string. The enclosing double quotes are
            NOT added, so the caller has to add them at some point.
    """
    # Characters that get a backslash in front of them.
    special = r' !"$&' "'" r'()*,:;<=>?[\]^`{|}'

    escaped = "".join(
            ("\\%s" % (ch,)) if ch in special else ch
            for ch in filename)

    return utils.sh_escape(escaped)
54
55
def get(location, local_copy = False):
    """Get a file or directory to a local temporary directory.

    Args:
            location: the source of the material to get. This source may
                    be one of:
                    * a local file or directory
                    * a URL (http or ftp)
                    * a python file-like object
                    * any other string, which is written verbatim to a
                      new temporary file
            local_copy: only used when location is an existing local
                    path. If False (the default) the path itself is
                    returned and nothing is copied; if True the file or
                    directory is copied into the temporary directory.

    Returns:
            The location of the file or directory where the requested
            content was saved. Except for the local_copy=False case, this
            is inside a temporary directory on the local host. If the
            material to get was a directory, the location will contain a
            trailing '/'.
            None is returned when location is neither a file-like object
            nor a string.
    """
    tmpdir = get_tmp_dir()

    # location is a file-like object
    if hasattr(location, "read"):
        tmpfile = os.path.join(tmpdir, "file")
        # NOTE(review): the destination is opened in text mode, so a
        # binary source stream will not copy cleanly on Python 3 - confirm
        # what callers pass in before changing the mode.
        with open(tmpfile, 'w') as tmpfileobj:
            shutil.copyfileobj(location, tmpfileobj)
        return tmpfile

    if not isinstance(location, six.string_types):
        return None

    # location is a URL
    if location.startswith(('http', 'ftp')):
        tmpfile = os.path.join(tmpdir, os.path.basename(location))
        utils.urlretrieve(location, tmpfile)
        return tmpfile

    # location is a local path
    if os.path.exists(os.path.abspath(location)):
        if not local_copy:
            # Hand back the original path, normalised to exactly one
            # trailing '/' for directories.
            if os.path.isdir(location):
                return location.rstrip('/') + '/'
            return location
        tmpfile = os.path.join(tmpdir, os.path.basename(location))
        if os.path.isdir(location):
            tmpfile += '/'
            shutil.copytree(location, tmpfile, symlinks=True)
            return tmpfile
        shutil.copyfile(location, tmpfile)
        return tmpfile

    # location is just a string, dump it to a file
    tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir)
    with os.fdopen(tmpfd, 'w') as tmpfileobj:
        tmpfileobj.write(location)
    return tmpfile
109
110
def get_tmp_dir():
    """Create a fresh temporary directory and return its pathname.

    The directory is made with tempfile.mkdtemp() using the "autoserv-"
    prefix and is recorded in the module-level __tmp_dirs registry under
    the current process id, which is what __clean_tmp_dirs() consults to
    delete directories that are still present.
    """
    created = tempfile.mkdtemp(prefix="autoserv-")
    __tmp_dirs.setdefault(os.getpid(), []).append(created)
    return created
124
125
def __clean_tmp_dirs():
    """Erase temporary directories that were created by the get_tmp_dir()
    function in the current process and that are still present.

    Removal is deliberately best-effort: a directory that has already
    been deleted, or an entry that cannot be removed, must not abort the
    cleanup of the remaining entries or raise out of an exit hook. Errors
    reported by shutil.rmtree() are therefore ignored.
    """
    pid = os.getpid()
    stale_dirs = __tmp_dirs.get(pid)
    if not stale_dirs:
        return
    for path in stale_dirs:
        # ignore_errors=True keeps deleting the rest of the tree after a
        # failure instead of stopping at the first OSError.
        shutil.rmtree(path, ignore_errors=True)
    __tmp_dirs[pid] = []


# Run the cleanup when the interpreter exits ...
atexit.register(__clean_tmp_dirs)
# ... and through the subcommand join hook. The hook's single argument is
# not needed and is discarded by the lambda.
# NOTE(review): presumably the hook fires when a subcommand is joined -
# confirm against subcommand.register_join_hook.
subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs())
142
143
def unarchive(host, source_material):
    """Uncompress and untar an archive on a host.

    If the "source_material" is compressed (according to the file
    extension) it will be uncompressed. Supported compression formats
    are gzip (".gz", ".gzip") and bzip2 (".bz2"). Afterwards, if the
    source_material is a tar archive (".tar"), it will be untarred.

    Args:
            host: the host object on which the archive is located; its
                    run() method is used to execute the commands
            source_material: the path of the archive on the host

    Returns:
            The file or directory name of the unarchived source material.
            If the material is a tar archive, it will be extracted in the
            directory where it is and the path returned will be the first
            entry in the archive, assuming it is the topmost directory.
            If the material is not an archive, nothing will be done so this
            function is "harmless" when it is "useless".

    Raises:
            IndexError: if tar reports no extracted entry at all.
    """
    # uncompress
    if source_material.endswith((".gz", ".gzip")):
        host.run('gunzip "%s"' % (utils.sh_escape(source_material)))
        # drop the compression suffix, as gunzip does
        source_material = source_material.rsplit(".", 1)[0]
    elif source_material.endswith(".bz2"):
        host.run('bunzip2 "%s"' % (utils.sh_escape(source_material)))
        source_material = source_material.rsplit(".", 1)[0]

    # untar
    if source_material.endswith(".tar"):
        parent_dir = os.path.dirname(source_material)
        # A path without a directory part has an empty dirname; extract
        # into the current directory instead of passing -C "" to tar.
        retval = host.run('tar -C "%s" -xvf "%s"' % (
                utils.sh_escape(parent_dir or "."),
                utils.sh_escape(source_material),))
        # tar -v prints one entry per line; take the whole first line so
        # that entry names containing whitespace are not truncated.
        first_entry = retval.stdout.splitlines()[0]
        source_material = os.path.join(parent_dir, first_entry)

    return source_material
182
183
def get_server_dir():
    """Return the absolute path of the directory that holds the
    'autotest_lib.server.utils' module.

    The module is looked up in sys.modules, so it must already have been
    imported; a KeyError is raised otherwise.
    """
    server_utils = sys.modules['autotest_lib.server.utils']
    return os.path.abspath(os.path.dirname(server_utils.__file__))
187
188
def find_pid(command):
    """Return the pid of the first process whose command matches.

    The output of 'ps -eo pid,cmd' is scanned line by line and the cmd
    column of each line is tested with re.search().

    Args:
            command: regular expression searched for in the cmd column.

    Returns:
            The pid (as an int) of the first matching line, or None when
            no line matches.
    """
    ps_lines = utils.system_output('ps -eo pid,cmd').rstrip().split('\n')
    matching_pids = (
            int(pid)
            for pid, cmd in (entry.split(None, 1) for entry in ps_lines)
            if re.search(command, cmd))
    return next(matching_pids, None)
195
196
def default_mappings(machines):
    """
    Returns a simple mapping in which all machines are assigned to the
    same key.  Provides the default behavior for
    form_ntuples_from_machines.

    Args:
            machines: sequence of machines; it is copied, not modified.

    Returns:
            A (mappings, failures) tuple. mappings is a dictionary with the
            single key 'ident' whose value is a new list holding every
            machine in its original order (an empty list when no machines
            were given). failures is always an empty list.
    """
    return ({'ident': list(machines)}, [])
213
214
def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings):
    """Returns a set of ntuples from machines where the machines in an
       ntuple are in the same mapping, and a set of failures which are
       (machine name, reason) tuples.

    Args:
            machines: the machines to group; passed to mapping_func.
            n: size of each ntuple; must be at least 1.
            mapping_func: callable taking machines and returning a
                    (mappings, failures) tuple, where mappings is a
                    dictionary of key -> list of machines.

    Returns:
            A (ntuples, failures) tuple. ntuples is a list of lists of n
            machines taken in order from each mapping. failures is the
            list returned by mapping_func, extended with one
            (machine, "machine can not be tupled") entry for every
            leftover machine that did not fill a complete ntuple.

    Raises:
            ValueError: if n is smaller than 1 (no ntuple can be formed,
                    and slicing by such a step would never terminate).
    """
    if n < 1:
        raise ValueError("n must be at least 1, got %r" % (n,))

    (mappings, failures) = mapping_func(machines)
    ntuples = []

    # now run through the mappings and create n-tuples.
    # throw out the odd guys out
    for key in mappings:
        key_machines = mappings[key]
        # number of leading machines that fill complete n-tuples
        usable = len(key_machines) - len(key_machines) % n

        # form n-tuples
        for start in range(0, usable, n):
            ntuples.append(key_machines[start:start + n])

        for mach in key_machines[usable:]:
            failures.append((mach, "machine can not be tupled"))

    return (ntuples, failures)
237
238
def parse_machine(machine, user='root', password='', port=22):
    """
    Parse the machine string user:pass@host:port and return it separately,
    if the machine string is not complete, use the default parameters
    when appropriate.

    Args:
            machine: the string to parse. Only the host part is mandatory;
                    an IPv6 address followed by a port must be enclosed in
                    brackets, e.g. "[fe80::1]:22".
            user: user name used when the string contains none.
            password: password used when the string contains none.
            port: port number used when the string contains none.

    Returns:
            A (machine, user, password, port) tuple. IPv6 brackets are
            stripped from machine and port is an int when it was parsed
            from the string.

    Raises:
            ValueError: if the host or the user is empty, or if a ':' is
                    given without a port number after it.
    """

    if '@' in machine:
        user, machine = machine.split('@', 1)

    if ':' in user:
        user, password = user.split(':', 1)

    # Check before looking at the host's first character below, so that an
    # empty host is reported as ValueError and not as IndexError.
    if not machine or not user:
        raise ValueError('machine string needs a non-empty host and user')

    # Brackets are required to protect an IPv6 address whenever a
    # [xx::xx]:port number (or a file [xx::xx]:/path/) is appended to
    # it. Do not attempt to extract a (non-existent) port number from
    # an unprotected/bare IPv6 address "xx::xx".
    # In the Python >= 3.3 future, 'import ipaddress' will parse
    # addresses; and maybe more.
    bare_ipv6 = (not machine.startswith('[')
                 and re.search(r':.*:', machine) is not None)

    # Extract trailing :port number if any.
    if not bare_ipv6 and re.search(r':\d*$', machine):
        machine, port_text = machine.rsplit(':', 1)
        if not port_text:
            raise ValueError('machine string has an empty port number')
        port = int(port_text)

    # Strip any IPv6 brackets (ssh does not support them).
    # We'll add them back later for rsync, scp, etc.
    if machine.startswith('[') and machine.endswith(']'):
        machine = machine[1:-1]

    # The host can still end up empty here, e.g. for ":22" or "[]".
    if not machine:
        raise ValueError('machine string does not contain a host')

    return machine, user, password, port
274
275
def get_public_key():
    """
    Return a valid string ssh public key for the user executing autoserv or
    autotest. If there's no DSA or RSA public key, create a DSA keypair with
    ssh-keygen and return it.

    A keypair only counts as present when both the private key file and
    its ".pub" file exist in ~/.ssh. An existing DSA keypair is preferred
    over an RSA one.

    Returns:
            The content of the selected public key file, as a string.
    """

    ssh_conf_path = os.path.expanduser('~/.ssh')

    dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa')
    dsa_public_key_path = dsa_private_key_path + '.pub'

    rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa')
    rsa_public_key_path = rsa_private_key_path + '.pub'

    has_dsa_keypair = (os.path.isfile(dsa_public_key_path) and
                       os.path.isfile(dsa_private_key_path))
    has_rsa_keypair = (os.path.isfile(rsa_public_key_path) and
                       os.path.isfile(rsa_private_key_path))

    if has_dsa_keypair:
        print('DSA keypair found, using it')
        public_key_path = dsa_public_key_path

    elif has_rsa_keypair:
        print('RSA keypair found, using it')
        public_key_path = rsa_public_key_path

    else:
        print('Neither RSA nor DSA keypair found, creating DSA ssh key pair')
        # NOTE(review): recent OpenSSH releases have deprecated DSA keys;
        # confirm that 'ssh-keygen -t dsa' still works on the target
        # systems before relying on this branch.
        # The path is escaped and double-quoted so that a home directory
        # containing shell-special characters does not break the command.
        utils.system('ssh-keygen -t dsa -q -N "" -f "%s"' %
                     utils.sh_escape(dsa_private_key_path))
        public_key_path = dsa_public_key_path

    with open(public_key_path, 'r') as public_key:
        return public_key.read()
314