1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tools for compiling and importing Python protos on the fly."""
15
from collections.abc import Mapping
import hashlib
import importlib.util
import logging
import os
from pathlib import Path
import shlex
import subprocess
import tempfile
from types import ModuleType
from typing import (Dict, Generic, Iterable, Iterator, List, NamedTuple, Set,
                    Tuple, TypeVar, Union)
27
# Logger for this module; compile_protos reports the protoc command line and
# any protoc failure through it.
_LOG = logging.getLogger(__name__)

# A filesystem path, accepted either as a pathlib.Path or as a plain string.
PathOrStr = Union[Path, str]
31
32
def compile_protos(
    output_dir: PathOrStr,
    proto_files: Iterable[PathOrStr],
    includes: Iterable[PathOrStr] = ()) -> None:
    """Compiles proto files for Python by invoking the protobuf compiler.

    Proto files not covered by one of the provided include paths will have their
    directory added as an include path.

    Include paths are handed to protoc in the order they were provided (with
    duplicates dropped), followed by the directories added for uncovered proto
    files. protoc searches its -I directories in order, so keeping a stable
    order makes the invocation reproducible from run to run.

    Args:
      output_dir: directory passed to protoc as the --python_out destination
      proto_files: paths to the .proto files to compile
      includes: directories passed to protoc as -I include paths

    Raises:
      subprocess.CalledProcessError: protoc exited with a non-zero status
    """
    proto_paths: List[Path] = [Path(f).resolve() for f in proto_files]

    # Keep the caller's include order; the set is only used to drop repeats.
    include_paths: List[Path] = []
    seen: Set[Path] = set()

    for include in (Path(d).resolve() for d in includes):
        if include not in seen:
            seen.add(include)
            include_paths.append(include)

    for path in proto_paths:
        if not any(include in path.parents for include in include_paths):
            # No listed include is an ancestor of this file, so its directory
            # cannot already be in the list.
            include_paths.append(path.parent)

    cmd: Tuple[PathOrStr, ...] = (
        'protoc',
        '--experimental_allow_proto3_optional',
        '--python_out',
        os.path.abspath(output_dir),
        *(f'-I{d}' for d in include_paths),
        *proto_paths,
    )
    cmd_str = ' '.join(shlex.quote(str(c)) for c in cmd)

    _LOG.debug('%s', cmd_str)
    process = subprocess.run(cmd, capture_output=True)

    if process.returncode:
        # Decode leniently so an undecodable byte in protoc's output cannot
        # mask the real failure with a UnicodeDecodeError.
        _LOG.error('protoc invocation failed!\n%s\n%s', cmd_str,
                   process.stderr.decode(errors='replace'))
        process.check_returncode()
66
67
68def _import_module(name: str, path: str) -> ModuleType:
69    spec = importlib.util.spec_from_file_location(name, path)
70    module = importlib.util.module_from_spec(spec)
71    spec.loader.exec_module(module)  # type: ignore[union-attr]
72    return module
73
74
def import_modules(directory: PathOrStr) -> Iterator:
    """Imports modules in a directory and yields them.

    The directory is walked recursively and every file with a .py extension is
    imported with _import_module. A module's name is its location relative to
    the parent of directory, with path separators replaced by dots, so the
    name of directory itself is the first component of every module name.

    Args:
      directory: root of the tree to search for .py files

    Yields:
      the imported modules, in the order os.walk reports their files
    """
    # Take the parent of the absolute path. dirname() of a path with a trailing
    # separator returns the directory itself, and dirname('.') returns '';
    # either would produce module names that begin with a '.'.
    parent = os.path.dirname(os.path.abspath(directory))

    for dirpath, _, files in os.walk(directory):
        path_parts = os.path.relpath(dirpath, parent).split(os.sep)

        for file in files:
            name, ext = os.path.splitext(file)

            if ext == '.py':
                yield _import_module(f'{".".join(path_parts)}.{name}',
                                     os.path.join(dirpath, file))
88
89
def compile_and_import(proto_files: Iterable[PathOrStr],
                       includes: Iterable[PathOrStr] = (),
                       output_dir: PathOrStr = None) -> Iterator:
    """Compiles protos and imports their modules; yields the proto modules.

    This is a generator: nothing is compiled until the first module is
    requested. When no output directory is supplied, the generated code lives
    in a temporary directory that exists only while the generator is running.

    Args:
      proto_files: paths to .proto files to compile
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory is
          used if omitted

    Yields:
      the generated protobuf Python modules
    """
    if not output_dir:
        # No destination requested: build in a scratch directory instead.
        with tempfile.TemporaryDirectory(
                prefix='compiled_protos_') as scratch_dir:
            compile_protos(scratch_dir, proto_files, includes)
            yield from import_modules(scratch_dir)
        return

    compile_protos(output_dir, proto_files, includes)
    yield from import_modules(output_dir)
112
113
def compile_and_import_file(proto_file: PathOrStr,
                            includes: Iterable[PathOrStr] = (),
                            output_dir: PathOrStr = None):
    """Compiles and imports the module for a single .proto file.

    Args:
      proto_file: path to the .proto file to compile
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated module; a temporary directory is
          used if omitted

    Returns:
      the first module yielded by compile_and_import for this file

    Raises:
      StopIteration: compile_and_import yielded no modules
    """
    modules = compile_and_import([proto_file], includes, output_dir)

    try:
        return next(iter(modules))
    finally:
        # Only the first module is consumed, which leaves the generator
        # suspended (possibly inside its temporary directory). Close it now
        # rather than relying on garbage collection to run its cleanup.
        modules.close()
119
120
def compile_and_import_strings(contents: Iterable[str],
                               includes: Iterable[PathOrStr] = (),
                               output_dir: PathOrStr = None) -> Iterator:
    """Compiles protos in one or more strings.

    Each string is written to a .proto file in a temporary directory and the
    files are then passed to compile_and_import.

    Args:
      contents: .proto source text; either a single string or an iterable of
          strings
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory is
          used if omitted

    Yields:
      the generated protobuf Python modules
    """

    if isinstance(contents, str):
        contents = [contents]

    with tempfile.TemporaryDirectory(prefix='proto_sources_') as path:
        protos: List[Path] = []

        for proto in contents:
            # Use a digest of the proto so the same contents map to the same
            # file name. The protobuf package complains if it sees the same
            # contents in files with different names. The built-in hash() is
            # not used because string hashes differ between interpreter runs
            # and may be negative, which would give unstable names containing
            # a '-'.
            digest = hashlib.sha256(proto.encode('utf-8')).hexdigest()[:16]
            proto_path = Path(path, f'protobuf_{digest}.proto')

            # Identical contents map to the same file; list it only once.
            if proto_path not in protos:
                proto_path.write_text(proto, encoding='utf-8')
                protos.append(proto_path)

        yield from compile_and_import(protos, includes, output_dir)
140
141
142T = TypeVar('T')
143
144
145class _NestedPackage(Generic[T]):
146    """Facilitates navigating protobuf packages as attributes."""
147    def __init__(self, package: str):
148        self._packages: Dict[str, _NestedPackage[T]] = {}
149        self._items: List[T] = []
150        self._package = package
151
152    def _add_package(self, subpackage: str, package: '_NestedPackage') -> None:
153        self._packages[subpackage] = package
154
155    def _add_item(self, item) -> None:
156        if item not in self._items:  # Don't store the same item multiple times.
157            self._items.append(item)
158
159    def __getattr__(self, attr: str):
160        """Look up subpackages or package members."""
161        if attr in self._packages:
162            return self._packages[attr]
163
164        for item in self._items:
165            if hasattr(item, attr):
166                return getattr(item, attr)
167
168        raise AttributeError(
169            f'Proto package "{self._package}" does not contain "{attr}"')
170
171    def __getitem__(self, subpackage: str) -> '_NestedPackage[T]':
172        """Support accessing nested packages by name."""
173        result = self
174
175        for package in subpackage.split('.'):
176            result = result._packages[package]
177
178        return result
179
180    def __dir__(self) -> List[str]:
181        """List subpackages and members of modules as attributes."""
182        attributes = list(self._packages)
183
184        for item in self._items:
185            for attr, value in vars(item).items():
186                # Exclude private variables and modules from dir().
187                if not attr.startswith('_') and not isinstance(
188                        value, ModuleType):
189                    attributes.append(attr)
190
191        return attributes
192
193    def __iter__(self) -> Iterator['_NestedPackage[T]']:
194        """Iterate over nested packages."""
195        return iter(self._packages.values())
196
197    def __repr__(self) -> str:
198        msg = [f'ProtoPackage({self._package!r}']
199
200        public_members = [
201            i for i in vars(self)
202            if i not in self._packages and not i.startswith('_')
203        ]
204        if public_members:
205            msg.append(f'members={str(public_members)}')
206
207        if self._packages:
208            msg.append(f'subpackages={str(list(self._packages))}')
209
210        return ', '.join(msg) + ')'
211
212    def __str__(self) -> str:
213        return self._package
214
215
class Packages(NamedTuple):
    """Items in a protobuf package structure; returned from as_packages."""
    # Maps a dotted package name to the list of items added under that name.
    items_by_package: Dict[str, List]
    # Root node of the attribute-navigable package tree. as_packages creates
    # it with the empty package name.
    packages: _NestedPackage
220
221
def as_packages(items: Iterable[Tuple[str, T]],
                packages: Packages = None) -> Packages:
    """Places items in a proto-style package structure navigable by attributes.

    Every item is appended to the list for its package name and is also added
    to the tree node for that package; intermediate nodes are created as
    needed.

    Args:
      items: (package, item) tuples to insert into the package structure
      packages: if provided, update this Packages instead of creating a new one

    Returns:
      the Packages that was created or updated
    """
    result = Packages({}, _NestedPackage('')) if packages is None else packages

    for package_name, item in items:
        result.items_by_package.setdefault(package_name, []).append(item)

        components = package_name.split('.')
        node = result.packages

        # pylint: disable=protected-access
        for depth, component in enumerate(components, 1):
            children = node._packages

            if component not in children:
                # A node is labelled with its fully qualified package name.
                qualified_name = '.'.join(components[:depth])
                node._add_package(component, _NestedPackage(qualified_name))

            node = children[component]

        node._add_item(item)
        # pylint: enable=protected-access

    return result
251
252
# What Library.from_paths accepts for each entry: a path to a .proto file
# (str or Path) or an already-imported proto module.
PathOrModule = Union[str, Path, ModuleType]
254
255
class Library:
    """A collection of protocol buffer modules sorted by package.

    In Python, each .proto file is compiled into a Python module. The Library
    class makes it simple to navigate a collection of Python modules
    corresponding to .proto files, without relying on the location of these
    compiled modules.

    Proto messages and other types can be directly accessed by their protocol
    buffer package name. For example, the foo.bar.Baz message can be accessed
    in a Library called `protos` as:

      protos.packages.foo.bar.Baz

    A Library also provides the modules_by_package dictionary, for looking up
    the list of modules in a particular package, and the modules() generator
    for iterating over all modules.
    """
    @classmethod
    def from_paths(cls, protos: Iterable[PathOrModule]) -> 'Library':
        """Creates a Library from paths to proto files or proto modules.

        Entries that are str or Path objects are compiled and imported with
        compile_and_import; every other entry is used as a module as-is.
        """
        paths: List[PathOrStr] = []
        modules: List[ModuleType] = []

        for proto in protos:
            if isinstance(proto, (Path, str)):
                paths.append(proto)
            else:
                modules.append(proto)

        if paths:
            modules += compile_and_import(paths)

        # Construct through cls, as from_strings does, so that calling this on
        # a subclass returns an instance of that subclass.
        return cls(modules)

    @classmethod
    def from_strings(cls,
                     contents: Iterable[str],
                     includes: Iterable[PathOrStr] = (),
                     output_dir: PathOrStr = None) -> 'Library':
        """Creates a proto library from protos in the provided strings.

        The arguments are forwarded unchanged to compile_and_import_strings.
        """
        return cls(compile_and_import_strings(contents, includes, output_dir))

    def __init__(self, modules: Iterable[ModuleType]):
        """Constructs a Library from an iterable of modules.

        A Library can be constructed with modules dynamically compiled by
        compile_and_import. For example:

            protos = Library(compile_and_import(list_of_proto_files))
        """
        # Each module is filed under the package name from its DESCRIPTOR.
        self.modules_by_package, self.packages = as_packages(
            (m.DESCRIPTOR.package, m)  # type: ignore[attr-defined]
            for m in modules)

    def modules(self) -> Iterable:
        """Iterates over all protobuf modules in this library."""
        for module_list in self.modules_by_package.values():
            yield from module_list

    def messages(self) -> Iterable:
        """Iterates over all protobuf messages in this library.

        Each module's top-level messages are yielded along with the messages
        nested inside them.
        """
        for module in self.modules():
            yield from _nested_messages(
                module, module.DESCRIPTOR.message_types_by_name)
320
321
322def _nested_messages(scope, message_names: Iterable[str]) -> Iterator:
323    for name in message_names:
324        msg = getattr(scope, name)
325        yield msg
326        yield from _nested_messages(msg, msg.DESCRIPTOR.nested_types_by_name)
327
328
329def _repr_char(char: int) -> str:
330    r"""Returns an ASCII char or the \x code for non-printable values."""
331    if ord(' ') <= char <= ord('~'):
332        return r"\'" if chr(char) == "'" else chr(char)
333
334    return f'\\x{char:02X}'
335
336
def bytes_repr(value: bytes) -> str:
    r"""Prints bytes as mixed ASCII only if at least half are printable.

    When at least half of the bytes fall between ord(' ') and ord('~'), each
    byte is rendered with _repr_char. Otherwise every byte is rendered as a
    \x code with two uppercase hex digits. The result is wrapped as b'...'.
    """
    printable_count = sum(1 for byte in value if ord(' ') <= byte <= ord('~'))

    if 2 * printable_count < len(value):
        # Mostly binary data: use hex for every byte.
        pieces = [f'\\x{byte:02X}' for byte in value]
    else:
        pieces = [_repr_char(byte) for byte in value]

    return "b'" + ''.join(pieces) + "'"
346
347
348def _field_repr(field, value) -> str:
349    if field.type == field.TYPE_ENUM:
350        try:
351            enum = field.enum_type.values_by_number[value]
352            return f'{field.enum_type.full_name}.{enum.name}'
353        except KeyError:
354            return repr(value)
355
356    if field.type == field.TYPE_MESSAGE:
357        return proto_repr(value)
358
359    if field.type == field.TYPE_BYTES:
360        return bytes_repr(value)
361
362    return repr(value)
363
364
365def _proto_repr(message) -> Iterator[str]:
366    for field in message.DESCRIPTOR.fields:
367        value = getattr(message, field.name)
368
369        # Skip fields that are not present.
370        try:
371            if not message.HasField(field.name):
372                continue
373        except ValueError:
374            # Skip default-valued fields that don't support HasField.
375            if (field.label != field.LABEL_REPEATED
376                    and value == field.default_value):
377                continue
378
379        if field.label == field.LABEL_REPEATED:
380            if not value:
381                continue
382
383            if isinstance(value, Mapping):
384                key_desc, value_desc = field.message_type.fields
385                values = ', '.join(
386                    f'{_field_repr(key_desc, k)}: {_field_repr(value_desc, v)}'
387                    for k, v in value.items())
388                yield f'{field.name}={{{values}}}'
389            else:
390                values = ', '.join(_field_repr(field, v) for v in value)
391                yield f'{field.name}=[{values}]'
392        else:
393            yield f'{field.name}={_field_repr(field, value)}'
394
395
def proto_repr(message) -> str:
    """Creates a repr-like string for a protobuf.

    The result is the message's full name followed by the comma-separated
    field strings from _proto_repr in parentheses.
    """
    type_name = message.DESCRIPTOR.full_name
    field_list = ', '.join(_proto_repr(message))
    return f'{type_name}({field_list})'
399