1# Copyright 2015 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""This module provides some utilities used by LXC and its tools.
6"""
7
8import logging
9import os
10import re
11import shutil
12import tempfile
13import unittest
14from contextlib import contextmanager
15
16import common
17from autotest_lib.client.bin import utils
18from autotest_lib.client.common_lib import error
19from autotest_lib.client.common_lib.cros.network import interface
20from autotest_lib.client.common_lib import global_config
21from autotest_lib.site_utils.lxc import constants
22from autotest_lib.site_utils.lxc import unittest_setup
23
24
def path_exists(path):
    """Return whether `path` exists, checking it with root privileges.

    When this process is not running as root, os.path.exists may be unable
    to see a path owned by root, so the check is delegated to
    `sudo test -e` instead.

    @param path: Path to check if it exists.

    @return: True if path exists, otherwise False.
    """
    command = 'sudo test -e "%s"' % path
    try:
        utils.run(command)
    except error.CmdError:
        # `test -e` exits non-zero when the path is missing.
        return False
    return True
41
42
def get_host_ip():
    """Return the host's IPv4 address on the lxc bridge interface.

    Interfaces are enumerated from /sys/class/net; the first one whose name
    begins with "lxcbr" is assumed to be the interface lxc uses.

    @return: IP address of the host running containers.

    @raises error.ContainerError: if no interface named lxcbr* exists.
    """
    # The kernel publishes symlinks to various network devices in /sys.
    listing = utils.run('ls /sys/class/net', ignore_status=True).stdout
    # split() with no separator already discards empty strings.
    interface_names = listing.split()

    bridge = next(
        (candidate for candidate in interface_names
         if candidate.startswith('lxcbr')),
        None)
    if bridge is None:
        raise error.ContainerError('Failed to find network interface used by '
                                   'lxc. All existing interfaces are: %s' %
                                   interface_names)
    return interface.Interface(bridge).ipv4_address
67
def is_vm():
    """Check if the process is running in a virtual machine.

    Runs `virt-what` through non-interactive sudo; any non-empty output is
    taken to mean the process is running in a virtual machine.

    @return: True if the process is running in a virtual machine, otherwise
             return False. Also returns False when the virt-what command
             fails.
    """
    try:
        virt = utils.run('sudo -n virt-what').stdout.strip()
    except error.CmdError:
        # NOTE(review): this also triggers when `sudo -n` cannot run without
        # a password, not only when virt-what is missing.
        logging.warning('Package virt-what is not installed, default to assume '
                        'it is not a virtual machine.')
        return False
    logging.debug('virt-what output: %s', virt)
    return bool(virt)
82
83
def destroy(path, name,
            force=True, snapshots=False, ignore_status=False, timeout=-1):
    """Destroy an LXC container.

    @param path: The LXC path that holds the container (passed to
                 `lxc-destroy -P`).
    @param name: The name of the container to destroy.
    @param force: Destroy even if running. Default true.
    @param snapshots: Destroy all snapshots based on the container. Default
                      false.
    @param ignore_status: Ignore return code of command. Default false.
    @param timeout: Seconds to wait for completion. No timeout imposed if the
                    value is negative. Default -1 (no timeout).

    @returns: CmdResult object from the shell command
    """
    cmd = 'sudo lxc-destroy -P %s -n %s' % (path, name)
    if force:
        cmd += ' -f'
    if snapshots:
        cmd += ' -s'

    run_kwargs = {'ignore_status': ignore_status}
    # Only hand a timeout to utils.run when the caller asked for one, so the
    # default (negative) value keeps utils.run's own default behaviour.
    if timeout >= 0:
        run_kwargs['timeout'] = timeout
    return utils.run(cmd, **run_kwargs)
106
def clone(lxc_path, src_name, new_path, dst_name, snapshot):
    """Clones a container with `lxc-copy`.

    @param lxc_path: The LXC path of the source container.
    @param src_name: The name of the source container.
    @param new_path: The LXC path of the destination container.
    @param dst_name: The name of the destination container.
    @param snapshot: Whether or not to create a snapshot clone.

    @raises error.CmdError: if lxc-copy fails (propagated from utils.run).
    """
    cmd = ('sudo lxc-copy --lxcpath {lxcpath} --newpath {newpath} '
           '--name {name} --newname {newname}').format(
               lxcpath=lxc_path,
               newpath=new_path,
               name=src_name,
               newname=dst_name)
    if snapshot and constants.SUPPORT_SNAPSHOT_CLONE:
        cmd += ' -s'
    # overlayfs is the default clone backend storage. However it is not
    # supported in Ganeti yet. Use aufs as the alternative. `snapshot` is
    # tested first so the `sudo virt-what` probe in is_vm() only runs when
    # its answer can change the command.
    if snapshot and is_vm():
        cmd += ' -B aufs'
    utils.run(cmd)
131
132
@contextmanager
def TempDir(*args, **kwargs):
    """Yield the path of a freshly created temporary directory.

    All arguments are forwarded unchanged to tempfile.mkdtemp. The directory
    and everything inside it is removed with shutil.rmtree when the
    with-block exits, whether it exits normally or by raising.
    """
    created_dir = tempfile.mkdtemp(*args, **kwargs)
    try:
        yield created_dir
    finally:
        # Clean up unconditionally, even if the with-body raised.
        shutil.rmtree(created_dir)
141
142
class BindMount(object):
    """Manages setup and cleanup of bind-mounts.

    An instance is identified by its spec, a (dir, mountpoint) tuple. Two
    instances with equal specs compare equal and hash identically, so they
    can be kept in sets or used as dict keys.
    """
    def __init__(self, spec):
        """Sets up a new bind mount.

        Do not call this directly, use the create or from_existing class
        methods.

        @param spec: A two-element tuple (dir, mountpoint) where dir is the
                     location of an existing directory, and mountpoint is the
                     path under that directory to the desired mount point.
        """
        self.spec = spec


    def __eq__(self, rhs):
        if isinstance(rhs, self.__class__):
            return self.spec == rhs.spec
        return NotImplemented


    def __ne__(self, rhs):
        return not (self == rhs)


    def __hash__(self):
        # Must agree with __eq__. Python 3 sets __hash__ to None on classes
        # that override __eq__ without defining __hash__, which would make
        # BindMount unusable in sets and as a dict key.
        return hash(self.spec)


    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.spec)


    @classmethod
    def create(cls, src, dst, rename=None, readonly=False):
        """Creates a new bind mount.

        @param src: The path of the source file/dir.
        @param dst: The destination directory.  The new mount point will be
                    ${dst}/${src} unless renamed.  If the mount point does not
                    already exist, it will be created.
        @param rename: An optional path to rename the mount.  If provided, the
                       mount point will be ${dst}/${rename} instead of
                       ${dst}/${src}.
        @param readonly: If True, the mount will be read-only.  False by
                         default.

        @return An object representing the bind-mount, which can be used to
                clean it up later.
        """
        # Strip the leading separator so os.path.join nests the mount point
        # under dst instead of discarding dst for an absolute path.
        spec = (dst, (rename if rename else src).lstrip(os.path.sep))
        full_dst = os.path.join(*spec)

        if not path_exists(full_dst):
            utils.run('sudo mkdir -p %s' % full_dst)

        utils.run('sudo mount --bind %s %s' % (src, full_dst))
        if readonly:
            # A bind mount cannot be made read-only in the initial mount
            # call; it has to be remounted with ro.
            utils.run('sudo mount -o remount,ro,bind %s' % full_dst)

        return cls(spec)


    @classmethod
    def from_existing(cls, host_dir, mount_point):
        """Creates a BindMount for an existing mount point.

        @param host_dir: Path of the host dir hosting the bind-mount.
        @param mount_point: Full path to the mount point (including the host
                            dir).

        @return An object representing the bind-mount, which can be used to
                clean it up later.
        """
        spec = (host_dir, os.path.relpath(mount_point, host_dir))
        return cls(spec)


    def cleanup(self):
        """Cleans up the bind-mount.

        Unmounts the destination, and deletes it if possible. If it was mounted
        alongside important files, it will not be deleted.
        """
        full_dst = os.path.join(*self.spec)
        utils.run('sudo umount %s' % full_dst)
        # Ignore errors because bind mount locations are sometimes nested
        # alongside actual file content (e.g. SSPs install into
        # /usr/local/autotest so rmdir -p will fail for any mounts located in
        # /usr/local/autotest).
        utils.run('sudo bash -c "cd %s; rmdir -p --ignore-fail-on-non-empty %s"'
                  % self.spec)
227
228
def is_subdir(parent, subdir):
    """Report whether `subdir` is located below `parent`.

    The comparison is purely lexical: the filesystem is not consulted and
    neither path is normalized, so callers should pass comparable paths.
    A path equal to `parent` (e.g. is_subdir('/usr', '/usr')) is not
    considered to be under it.

    @param parent: The parent directory.
    @param subdir: The subdirectory.

    @return True if subdir lies under parent, False otherwise.
    """
    # Requiring a separator right after `parent` keeps '/usr' from matching
    # '/usrlocal'.
    parent_prefix = os.path.join(parent, '')
    return subdir.startswith(parent_prefix)
241
242
def sudo_commands(commands):
    """Run a list of bash commands as root.

    When the LXC_POOL/combine_sudos config value is true, all commands are
    written to one temporary script (prefixed with `set -e`, so it stops at
    the first failing command) and executed with a single `sudo bash`
    invocation. This saves ~400 ms per command. Otherwise each command gets
    its own `sudo` call.

    @param commands: The bash commands, as strings.

    @return The CmdResult of the combined `sudo bash` call or, when commands
            are not combined, the CmdResult of the last command (None if
            `commands` is empty).
    """
    combine = global_config.global_config.get_config_value(
        'LXC_POOL', 'combine_sudos', type=bool, default=False)

    if combine:
        # Text mode so that str lines can be written on Python 2 and 3 alike.
        with tempfile.NamedTemporaryFile(mode='w') as temp:
            temp.write("set -e\n")
            temp.writelines([command + "\n" for command in commands])
            # The script is read by a separate bash process; without a flush
            # the data may still sit in our write buffer and bash would run
            # an empty or truncated script.
            temp.flush()
            logging.info("Commands to run: %s", str(commands))
            return utils.run("sudo bash %s" % temp.name)

    result = None
    for command in commands:
        result = utils.run("sudo %s" % command)
    return result
264
265
def get_lxc_version():
    """Gets the current version of lxc if available.

    Parses the output of `lxc-info --version`, splitting it on '.' and '-'
    and keeping only the first three components (any further '-suffix' is
    ignored).

    @return A list of three ints (e.g. [2, 0, 7]), or None if lxc-info fails
            or its output is not in the expected format.
    """
    # ignore_status=True: without it utils.run raises on a non-zero exit and
    # the failure branch below could never be reached.
    result = utils.run('sudo lxc-info --version', ignore_status=True)
    if not result or result.exit_status != 0:
        logging.error("Unable to determine LXC version.")
        return None

    version_text = result.stdout.strip()
    components = re.split("[.-]", version_text)
    if len(components) < 3:
        logging.error("LXC version is not expected format %s.", version_text)
        return None

    version = []
    for component in components[:3]:
        try:
            version.append(int(component))
        except ValueError:
            logging.error(("LXC version contains non numerical version "
                           "number %s (%s)."), component, version_text)
            return None
    return version
288
class LXCTests(unittest.TestCase):
    """Thin wrapper to call correct setup for LXC tests.

    Derive LXC test cases from this class instead of unittest.TestCase so
    that unittest_setup.setup() runs before the tests of each test class.
    """

    @classmethod
    def setUpClass(cls):
        """Run the shared LXC unit test setup once per test class."""
        # Chain to the base implementation so setUpClass hooks of other
        # bases (e.g. mixins) are not silently skipped.
        super(LXCTests, cls).setUpClass()
        unittest_setup.setup()
295