1#!/usr/bin/env python3
2#
3#   Copyright 2019 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the 'License');
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an 'AS IS' BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17import collections
18import os
19
20
class InvalidParamError(Exception):
    """Raised when a config parameter fails validation."""


class ConfigWrapper(collections.UserDict):
    """Class representing a test or preparer config.

    Behaves like a dict, except that every plain ``dict`` value is stored as
    a ``ConfigWrapper``. This applies both at construction time and on later
    item assignment, so sub-configs can always be read with the typed
    getters below.
    """

    def __init__(self, config=None):
        """Initialize a ConfigWrapper

        Args:
            config: A dict (or any object with ``items()``) representing the
                preparer/test parameters. Nested dicts are converted to
                ConfigWrappers; the passed-in object itself is not modified.
        """
        super().__init__()
        if config is not None:
            for key, val in config.items():
                # Route through __setitem__ so nested dicts get wrapped.
                self[key] = val

    def __setitem__(self, key, value):
        """Store value under key, wrapping a plain dict in a ConfigWrapper.

        MutableMapping helpers such as update() and setdefault() also go
        through here, so dicts added after construction are wrapped too.
        """
        if isinstance(value, dict):
            value = ConfigWrapper(value)
        super().__setitem__(key, value)

    def get(self, param_name, default=None, verify_fn=lambda _: True,
            failure_msg=''):
        """Get parameter from config, verifying that the value is valid
        with verify_fn.

        Args:
            param_name: Name of the param to fetch
            default: Default value of param. Note that the default is also
                passed through verify_fn when the param is absent.
            verify_fn: Callable to verify the param value. If it returns False,
                an exception will be raised.
            failure_msg: Exception message upon verify_fn failure.

        Returns:
            The stored value, or default if param_name is not present.

        Raises:
            InvalidParamError: if verify_fn returns a falsy value.
        """
        result = self.data.get(param_name, default)
        if not verify_fn(result):
            raise InvalidParamError('Invalid value "%s" for param %s. %s'
                                    % (result, param_name, failure_msg))
        return result

    def get_config(self, param_name):
        """Get a sub-config from config. Returns an empty ConfigWrapper if no
        such sub-config is found.

        Raises:
            InvalidParamError: if the param is present but its value is not a
                sub-config (i.e. not a ConfigWrapper).
        """
        # A fresh default per call avoids sharing one mutable empty config.
        return self.get(param_name, default=ConfigWrapper(),
                        verify_fn=lambda val: isinstance(val, ConfigWrapper),
                        failure_msg='Param must be a config (dict).')

    def get_int(self, param_name, default=0):
        """Get integer parameter from config. Will raise an exception
        if result is not of type int.

        The check is an exact type match, so bool values (a subclass of int)
        are rejected.

        Raises:
            InvalidParamError: if the value (or default) is not exactly int.
        """
        return self.get(param_name, default=default,
                        verify_fn=lambda val: type(val) is int,
                        failure_msg='Param must be of type int.')

    def get_numeric(self, param_name, default=0):
        """Get int or float parameter from config. Will raise an exception if
        result is not of type int or float.

        The check is an exact type match, so bool values are rejected.

        Raises:
            InvalidParamError: if the value (or default) is not exactly int
                or float.
        """
        return self.get(param_name, default=default,
                        verify_fn=lambda val: type(val) in (int, float),
                        failure_msg='Param must be of type int or float.')
82