1import os
2import pathlib
3import zipfile
4import tempfile
5import functools
6import contextlib
7
8
def from_package(package):
    """
    Return a Traversable object for the given package.

    The package's import spec (``package.__spec__``) is handed to
    ``fallback_resources``, which produces either a ``zipfile.Path`` (for a
    package imported from a zip archive) or the package's directory as a
    ``pathlib.Path``.
    """
    package_spec = package.__spec__
    return fallback_resources(package_spec)
15
16
def fallback_resources(spec):
    """
    Return a Traversable for the package described by *spec*.

    If the package's loader exposes an ``archive`` attribute (as zip
    importers do) and the package directory lies inside that archive, a
    ``zipfile.Path`` positioned at the package's directory within the
    archive is returned. Otherwise the package's directory on the file
    system is returned as a ``pathlib.Path``.

    This is best-effort: if the zip-based view cannot be built (no archive,
    directory outside the archive, archive not openable as a zip file), the
    file-system directory is returned instead of raising.
    """
    package_directory = pathlib.Path(spec.origin).parent
    archive_path = getattr(spec.loader, 'archive', None)
    if archive_path is None:
        return package_directory
    try:
        rel_path = package_directory.relative_to(archive_path)
    except (ValueError, TypeError):
        # The package directory is not inside the archive, or ``archive``
        # is not a path-like value.
        return package_directory
    # Zip member names always use '/', so render the relative path with
    # as_posix(); str() would produce '\\' separators on Windows. A package
    # sitting at the archive root maps to '' (the zip root), not './'.
    at = rel_path.as_posix() + '/' if rel_path.parts else ''
    try:
        return zipfile.Path(archive_path, at)
    except (OSError, zipfile.BadZipFile):
        # The archive could not be opened as a zip file.
        return package_directory
26
27
28@contextlib.contextmanager
29def _tempfile(reader, suffix=''):
30    # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try'
31    # blocks due to the need to close the temporary file to work on Windows
32    # properly.
33    fd, raw_path = tempfile.mkstemp(suffix=suffix)
34    try:
35        os.write(fd, reader())
36        os.close(fd)
37        yield pathlib.Path(raw_path)
38    finally:
39        try:
40            os.remove(raw_path)
41        except FileNotFoundError:
42            pass
43
44
@functools.singledispatch
@contextlib.contextmanager
def as_file(path):
    """
    Context manager exposing a Traversable as a path on the local file system.

    This is the generic implementation, used for any object without a more
    specific registration. The object's bytes (``path.read_bytes``) are
    copied into a temporary file whose name ends with ``path.name``, so the
    original file name and extension are kept. The temporary
    ``pathlib.Path`` is yielded, and the copy is removed when the ``with``
    block exits.
    """
    source_reader = path.read_bytes
    name_suffix = path.name
    with _tempfile(source_reader, suffix=name_suffix) as temp_path:
        yield temp_path
54
55
@as_file.register(pathlib.Path)
@contextlib.contextmanager
def _(path):
    """
    as_file specialization for ``pathlib.Path`` instances.

    Such objects already name a location on the local file system, so no
    temporary copy is made: *path* itself is yielded unchanged, and nothing
    is cleaned up when the ``with`` block exits.
    """
    # Already a real file-system path; hand it straight back.
    yield path
63