1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Provides general purpose plugin functionality.
15
16As used in this module, a plugin is a Python object associated with a name.
17Plugins are registered in a Registry. The plugin object is typically a function,
18but can be anything.
19
20Plugins may be loaded in a variety of ways:
21
22- Listed in a plugins file in the file system (e.g. as "name module target").
23- Registered in a Python file using a decorator (@my_registry.plugin).
24- Registered directly or by name with function calls on a registry object.
25
26This functionality can be used to create plugins for command line tools,
27interactive consoles, or anything else. Pigweed's pw command uses this module
28for its plugins.
29"""
30
31import collections
32import collections.abc
33import importlib
34import inspect
35import logging
36from pathlib import Path
37import pkgutil
38import sys
39from textwrap import TextWrapper
40import types
41from typing import Any, Callable, Dict, List, Iterable, Iterator, Optional, Set
42
43_LOG = logging.getLogger(__name__)
44_BUILT_IN = '<built-in>'
45
46
class Error(Exception):
    """Signals that a plugin is invalid or could not be registered."""
    def __str__(self):
        """Renders the message, appending the __cause__ when one is set.

        Showing the cause inline explains the failure without a backtrace.
        """
        message = super().__str__()
        cause = self.__cause__
        if cause is not None:
            message = f'{message} ({type(cause).__name__}: {cause})'
        return message
59
60
61def _get_module(member: object) -> types.ModuleType:
62    """Gets the module or a dummy module if the module isn't found."""
63    module = inspect.getmodule(member)
64    return module if module else types.ModuleType('<unknown>')
65
66
class Plugin:
    """Represents a Python entity registered as a plugin.

    Each plugin resolves to a Python object, typically a function. It also
    records the module defining that object and, optionally, the plugins file
    that declared it.
    """
    @classmethod
    def from_name(cls, name: str, module_name: str, member_name: str,
                  source: Optional[Path]) -> 'Plugin':
        """Creates a plugin by module and attribute name.

        Args:
          name: the name of the plugin
          module_name: Python module name (e.g. 'foo_pkg.bar')
          member_name: the name of the member in the module
          source: path to the plugins file that declared this plugin, if any

        Raises:
          Error: the module failed to import or has no such member
        """

        # Attempt to access the module and member. Catch any errors that might
        # occur, since a bad plugin shouldn't be a fatal error.
        try:
            module = importlib.import_module(module_name)
        except Exception as err:
            raise Error(f'Failed to import module "{module_name}"') from err

        try:
            member = getattr(module, member_name)
        except AttributeError as err:
            raise Error(
                f'"{module_name}.{member_name}" does not exist') from err

        return cls(name, member, source)

    def __init__(self,
                 name: str,
                 target: Any,
                 source: Optional[Path] = None) -> None:
        """Creates a plugin for the provided target.

        Args:
          name: the name of the plugin
          target: the object the plugin resolves to, typically a function
          source: path to the plugins file that declared this plugin, or None
              (reported as '<built-in>' by source_name)
        """
        self.name = name
        self._module = _get_module(target)
        self.target = target
        self.source = source

    @property
    def target_name(self) -> str:
        """The target's module name and name, joined with a dot."""
        return (f'{self._module.__name__}.'
                f'{getattr(self.target, "__name__", self.target)}')

    @property
    def source_name(self) -> str:
        """The plugins file path as a string, or '<built-in>' if none."""
        return _BUILT_IN if self.source is None else str(self.source)

    def run_with_argv(self, argv: Iterable[str]) -> int:
        """Sets sys.argv and calls the plugin function.

        This is used to call a plugin as if from the command line. The
        original sys.argv is restored even if the target raises.
        """
        original_sys_argv = sys.argv
        sys.argv = [f'pw {self.name}', *argv]

        try:
            return self.target()
        finally:
            sys.argv = original_sys_argv

    def help(self, full: bool = False) -> str:
        """Returns a description of this plugin from its docstring.

        Uses the target's docstring, falling back to the docstring of the
        module that defines the target. With full=False only the first
        non-blank line is returned, so docstrings that start on the line after
        the opening quotes still produce a summary.
        """
        docstring = self.target.__doc__ or self._module.__doc__ or ''
        if full:
            return docstring

        # strip() skips leading blank lines, e.g. '"""\n    Summary ...'.
        return next(iter(docstring.strip().splitlines()), '')

    def details(self, full: bool = False) -> Iterator[str]:
        """Yields 'label  value' lines describing this plugin."""
        yield f'help    {self.help(full=full)}'
        yield f'module  {self._module.__name__}'
        yield f'target  {getattr(self.target, "__name__", self.target)}'
        yield f'source  {self.source_name}'

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(name={self.name!r}, '
                f'target={self.target_name}'
                f'{f", source={self.source_name!r}" if self.source else ""})')
143
144
def callable_with_no_args(plugin: Plugin) -> None:
    """Checks that a plugin is callable without arguments.

    May be used for the validator argument to Registry.

    A parameter counts as required when it has no default value. Variadic
    parameters (*args and **kwargs) never need a value, so they are ignored.

    Raises:
      Error: inspect.signature() rejected the target with a TypeError (e.g.
          it is not callable), or the target has required parameters.
    """
    try:
        params = inspect.signature(plugin.target).parameters
    except TypeError:
        raise Error('Plugin functions must be callable, but '
                    f'{plugin.target_name} is a '
                    f'{type(plugin.target).__name__}')

    variadic = (inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD)
    # Compare against the sentinel with "is": "==" would call the default
    # value's own __eq__, which need not return a plain bool.
    required = [
        param for param in params.values()
        if param.default is param.empty and param.kind not in variadic
    ]
    if required:
        raise Error(f'Plugin functions cannot have any required positional '
                    f'arguments, but {plugin.target_name} has {len(required)}')
161
162
class Registry(collections.abc.Mapping):
    """Manages a set of plugins from Python modules or plugins files.

    Maps plugin names to Plugin objects. Failed registrations are recorded by
    name and reported when that name is looked up.
    """
    def __init__(self,
                 validator: Callable[[Plugin], Any] = lambda _: None) -> None:
        """Creates a new, empty plugins registry.

        Args:
          validator: Function that checks whether a plugin is valid and should
              be registered. Must raise plugins.Error if the plugin is invalid.
        """

        self._registry: Dict[str, Plugin] = {}
        self._sources: Set[Path] = set()  # Paths to plugins files
        # Registration failures keyed by plugin name; plugins-file lines that
        # could not be parsed are keyed by the line text itself.
        self._errors: Dict[str,
                           List[Exception]] = collections.defaultdict(list)
        self._validate_plugin = validator

    def __getitem__(self, name: str) -> Plugin:
        """Accesses a plugin by name; raises KeyError if it does not exist.

        If registering the name failed, the KeyError message lists the
        recorded errors.
        """
        if name in self._registry:
            return self._registry[name]

        if name in self._errors:
            raise KeyError(f'Registration for "{name}" failed: ' +
                           ', '.join(str(e) for e in self._errors[name]))

        raise KeyError(f'The plugin "{name}" has not been registered')

    def __iter__(self) -> Iterator[str]:
        return iter(self._registry)

    def __len__(self) -> int:
        return len(self._registry)

    def errors(self) -> Dict[str, List[Exception]]:
        """Returns the recorded registration errors (the live mapping)."""
        return self._errors

    def run_with_argv(self, name: str, argv: Iterable[str]) -> int:
        """Runs a plugin by name, setting sys.argv to the provided args.

        This is used to run a command as if it were executed directly from the
        command line. The plugin is expected to return an int.

        Raises:
          KeyError if plugin is not registered.
        """
        return self[name].run_with_argv(argv)

    def _should_register(self, plugin: Plugin) -> bool:
        """Determines and logs if a plugin should be registered or not.

        Some errors are exceptions, others are not. A built-in plugin that
        reuses a registered name raises; a plugin from a plugins file may
        replace a built-in one; any other duplicate is only logged, and the
        first registration wins.

        Returns:
          True if the plugin should be stored under its name.

        Raises:
          Error: a built-in plugin (no source file) reuses an existing name.
          Whatever the validator raises (plugins.Error for invalid plugins).
        """

        if plugin.name in self._registry and plugin.source is None:
            raise Error(
                f'Attempted to register built-in plugin "{plugin.name}", but '
                'a plugin with that name was previously registered '
                f'({self[plugin.name]})!')

        # Run the user-provided validation function, which raises exceptions
        # if there are errors.
        self._validate_plugin(plugin)

        existing = self._registry.get(plugin.name)

        if existing is None:
            return True

        if existing.source is None:
            _LOG.debug('%s: Overriding built-in plugin "%s" with %s',
                       plugin.source_name, plugin.name, plugin.target_name)
            return True

        if plugin.source != existing.source:
            _LOG.debug(
                '%s: The plugin "%s" was previously registered in %s; '
                'ignoring registration as %s', plugin.source_name, plugin.name,
                existing.source_name, plugin.target_name)
        elif plugin.source not in self._sources:
            # register_file adds a path to _sources only after parsing it, so
            # this flags duplicate entries within the file being parsed.
            _LOG.warning(
                '%s: "%s" is registered multiple times in this file! '
                'Only the first registration takes effect', plugin.source_name,
                plugin.name)

        return False

    def register(self, name: str, target: Any) -> Optional[Plugin]:
        """Registers an object as a built-in plugin (one with no source).

        Returns:
          The new Plugin, or None if an earlier registration takes precedence.
        """
        return self._register(Plugin(name, target, None))

    def register_by_name(self,
                         name: str,
                         module_name: str,
                         member_name: str,
                         source: Optional[Path] = None) -> Optional[Plugin]:
        """Registers an object from its module and name as a plugin.

        Returns:
          The new Plugin, or None if an earlier registration takes precedence.

        Raises:
          Error: the module or member could not be loaded, or the plugin was
              rejected.
        """
        return self._register(
            Plugin.from_name(name, module_name, member_name, source))

    def _register(self, plugin: Plugin) -> Optional[Plugin]:
        # Prohibit functions not from a plugins file from overriding others.
        if not self._should_register(plugin):
            return None

        self._registry[plugin.name] = plugin
        _LOG.debug('%s: Registered plugin "%s" for %s', plugin.source_name,
                   plugin.name, plugin.target_name)

        return plugin

    def register_file(self, path: Path) -> None:
        """Registers plugins from a plugins file.

        Each non-blank line that does not start with '#' must contain three
        whitespace-separated items: name, module, and member. Malformed lines
        and plugins.Error failures are recorded in errors() and logged; other
        exceptions (e.g. failing to open the file) propagate.
        """
        with path.open() as contents:
            for lineno, line in enumerate(contents, 1):
                line = line.strip()
                if not line or line.startswith('#'):
                    continue

                try:
                    name, module, function = line.split()
                except ValueError as err:
                    self._errors[line].append(Error(err))
                    _LOG.error(
                        '%s:%d: Failed to parse plugin entry "%s": '
                        'Expected 3 items (name, module, function), '
                        'got %d', path, lineno, line, len(line.split()))
                    continue

                try:
                    self.register_by_name(name, module, function, path)
                except Error as err:
                    self._errors[name].append(err)
                    _LOG.error('%s: Failed to register plugin "%s": %s', path,
                               name, err)

        self._sources.add(path)

    def register_directory(self,
                           directory: Path,
                           file_name: str,
                           restrict_to: Optional[Path] = None) -> None:
        """Finds and registers plugins from plugins files in a directory.

        Args:
          directory: The directory from which to start searching up.
          file_name: The name of plugins files to look for.
          restrict_to: If provided, do not search higher than this directory.
        """
        if restrict_to is not None:
            # find_all_in_parents yields resolved paths; resolve restrict_to
            # too so relative or symlinked directories still compare equal.
            restrict_to = restrict_to.resolve()

        for path in find_all_in_parents(file_name, directory):
            if not path.is_file():
                continue

            if restrict_to is not None and restrict_to not in path.parents:
                _LOG.debug(
                    "Skipping plugins file %s because it's outside of %s",
                    path, restrict_to)
                continue

            _LOG.debug('Found plugins file %s', path)
            self.register_file(path)

    def short_help(self) -> str:
        """Returns a help string for the registered plugins."""
        width = max((len(name) for name in self._registry), default=0) + 1
        help_items = '\n'.join(
            f'  {name:{width}} {plugin.help()}'
            for name, plugin in sorted(self._registry.items()))
        return f'supported plugins:\n{help_items}'

    def detailed_help(self, plugins: Iterable[str] = ()) -> Iterator[str]:
        """Yields lines of detailed information about commands.

        Args:
          plugins: names of the plugins to describe; every registered plugin
              is described if this is empty (including an empty iterator).
        """
        names = sorted(plugins)
        if not names:
            names = sorted(self._registry)

        yield '\ndetailed plugin information:'

        wrapper = TextWrapper(width=80,
                              initial_indent='   ',
                              subsequent_indent=' ' * 11)

        for name in names:
            yield f'  [{name}]'

            try:
                for line in self[name].details(full=len(names) == 1):
                    yield wrapper.fill(line)
            except KeyError as err:
                # str() of a KeyError quotes its message; drop the quotes.
                yield wrapper.fill(f'error   {str(err)[1:-1]}')

            yield ''

        yield 'Plugins files:'

        if self._sources:
            # Sort so the listing is stable; set iteration order is not.
            yield from (f'  [{i}] {file}'
                        for i, file in enumerate(sorted(self._sources), 1))
        else:
            yield '  (none found)'

    def plugin(self,
               function: Optional[Callable] = None,
               *,
               name: Optional[str] = None) -> Callable[[Callable], Callable]:
        """Decorator that registers a function with this plugin registry.

        Usable bare (@registry.plugin) or with arguments
        (@registry.plugin(name='alias')). The plugin is named after the
        function's __name__ unless name is given.

        Returns:
          The function unchanged, or the decorator when no function is given.
        """
        def decorator(function: Callable) -> Callable:
            self.register(function.__name__ if name is None else name,
                          function)
            return function

        if function is None:
            return decorator

        # Delegate so an explicit name is honored even when the function is
        # passed directly, e.g. registry.plugin(func, name='alias').
        return decorator(function)
383
384
def find_in_parents(name: str, path: Path) -> Optional[Path]:
    """Searches parent directories of the path for a file or directory.

    The search starts in path itself (after resolving it) and walks up to and
    including the filesystem root.

    Args:
      name: file or directory name joined onto each directory searched
      path: directory where the search starts; it does not need to exist

    Returns:
      The resolved path of the nearest match, or None if nothing matches.
    """
    return next(find_all_in_parents(name, path), None)


def find_all_in_parents(name: str, path: Path) -> Iterator[Path]:
    """Searches all parent directories of the path for files or directories.

    Yields every existing <directory>/<name>, nearest first, for path itself
    (after resolving it) and each of its parents up to and including the
    filesystem root. Every directory is checked exactly once, so the search
    terminates even for names that contain a separator (such as 'a/b').
    """
    start = path.resolve()
    for directory in (start, *start.parents):
        candidate = directory.joinpath(name)
        if candidate.exists():
            yield candidate
408
409
def import_submodules(module: types.ModuleType,
                      recursive: bool = False) -> None:
    """Imports every submodule of a package.

    Importing the submodules runs any @registry.plugin decorators in them, so
    this collects decorator-registered plugins from a package directory. With
    recursive=True, nested subpackages are walked as well.
    """
    search_path = module.__path__  # type: ignore[attr-defined]
    prefix = module.__name__ + '.'
    finder = pkgutil.walk_packages if recursive else pkgutil.iter_modules

    for module_info in finder(search_path, prefix):
        importlib.import_module(module_info.name)
425