1"""
2colorLib.table_builder: Generic helper for filling in BaseTable derivatives from tuples and maps and such.
3
4"""
5
6import collections
7import enum
8from fontTools.ttLib.tables.otBase import (
9    BaseTable,
10    FormatSwitchingBaseTable,
11    UInt8FormatSwitchingBaseTable,
12)
13from fontTools.ttLib.tables.otConverters import (
14    ComputedInt,
15    SimpleValue,
16    Struct,
17    Short,
18    UInt8,
19    UShort,
20    VarInt16,
21    VarUInt16,
22    IntValue,
23    FloatValue,
24)
25from fontTools.misc.roundTools import otRound
26
27
class BuildCallback(enum.Enum):
    """Lifecycle hooks a TableBuilder callback table may be keyed on.

    Each member is the first element of a callback-table key:

    - BEFORE_BUILD: keyed on (BEFORE_BUILD, class[, Format if available]).
      Receives (dest, source). Should return (dest, source), which can be
      new objects.
    - AFTER_BUILD: keyed on (AFTER_BUILD, class[, Format if available]).
      Receives (dest). Should return dest, which can be a new object.
    - CREATE_DEFAULT: keyed on (CREATE_DEFAULT, class).
      Receives no arguments. Should return a new instance of class.
    """

    # Invoked before the source fields are assigned onto dest.
    BEFORE_BUILD = enum.auto()
    # Invoked after all fields have been assigned.
    AFTER_BUILD = enum.auto()
    # Invoked to create the empty instance that is then populated.
    CREATE_DEFAULT = enum.auto()
47
48
49def _assignable(convertersByName):
50    return {k: v for k, v in convertersByName.items() if not isinstance(v, ComputedInt)}
51
52
def convertTupleClass(tupleClass, value):
    """Coerce *value* into an instance of *tupleClass*.

    An existing instance is returned unchanged. A plain tuple is unpacked as
    positional arguments. Any other value is passed as the single argument.
    """
    if not isinstance(value, tupleClass):
        if isinstance(value, tuple):
            value = tupleClass(*value)
        else:
            value = tupleClass(value)
    return value
59
60
61def _isNonStrSequence(value):
62    return isinstance(value, collections.abc.Sequence) and not isinstance(value, str)
63
64
def _set_format(dest, source):
    """Read the Format from *source*, store it on ``dest.Format``, and return the rest.

    If *source* is a non-string sequence, its first item is the Format and the
    remaining items are returned. If *source* is a mapping, its "Format" key is
    read and the mapping is returned unchanged (TableBuilder.build skips that
    key when assigning fields).

    Raises:
        ValueError: if *source* is neither a non-string sequence nor a mapping.
        AssertionError: if the Format is missing, unhashable, or not a key of
            ``dest.convertersByName``.
    """
    if _isNonStrSequence(source):
        assert len(source) > 0, f"{type(dest)} needs at least format from {source}"
        dest.Format = source[0]
        source = source[1:]
    elif isinstance(source, collections.abc.Mapping):
        assert "Format" in source, f"{type(dest)} needs at least Format from {source}"
        dest.Format = source["Format"]
    else:
        raise ValueError(f"Not sure how to populate {type(dest)} from {source}")

    # Checked first so the dict membership test below cannot raise TypeError.
    assert isinstance(
        dest.Format, collections.abc.Hashable
    ), f"{type(dest)} Format is not hashable: {dest.Format}"
    # The message previously referenced an undefined name `cls`, which turned
    # this assertion failure into a NameError; use the table's type instead.
    assert (
        dest.Format in dest.convertersByName
    ), f"{dest.Format} invalid Format of {type(dest)}"

    return source
84
85
class TableBuilder:
    """
    Helps to populate things derived from BaseTable from maps, tuples, etc.

    A table of lifecycle callbacks may be provided to add logic beyond what is possible
    based on otData info for the target class. See BuildCallback.
    """

    def __init__(self, callbackTable=None):
        # Keys are (BuildCallback member, class[, Format]) tuples; values are
        # the callables described on BuildCallback.
        if callbackTable is None:
            callbackTable = {}
        self._callbackTable = callbackTable

    def _convert(self, dest, field, converter, value):
        """Coerce *value* according to *converter* and assign it to ``dest.<field>``.

        Dispatch, in order:
        - converters with a ``tupleClass`` wrap the value via convertTupleClass;
        - converters with an ``enumClass`` accept a member, a member name in
          any case, or a raw value passed to ``enumClass(value)``;
        - IntValue converters round with otRound; FloatValue converters use float();
        - Struct converters build the value recursively. For repeated Structs,
          a single non-sequence value becomes a one-element list, and the list
          length is stored in the field named by ``converter.repeat``;
        - any other callable converter is applied to the value.
        Otherwise the value is assigned unchanged.

        Raises:
            ValueError: if a string does not name a member of ``enumClass``.
        """
        tupleClass = getattr(converter, "tupleClass", None)
        enumClass = getattr(converter, "enumClass", None)

        if tupleClass:
            value = convertTupleClass(tupleClass, value)

        elif enumClass:
            if isinstance(value, enumClass):
                pass
            elif isinstance(value, str):
                # Member names are matched case-insensitively.
                try:
                    value = getattr(enumClass, value.upper())
                except AttributeError:
                    raise ValueError(f"{value} is not a valid {enumClass}") from None
            else:
                value = enumClass(value)

        elif isinstance(converter, IntValue):
            value = otRound(value)
        elif isinstance(converter, FloatValue):
            value = float(value)

        elif isinstance(converter, Struct):
            if converter.repeat:
                if _isNonStrSequence(value):
                    value = [self.build(converter.tableClass, v) for v in value]
                else:
                    value = [self.build(converter.tableClass, value)]
                setattr(dest, converter.repeat, len(value))
            else:
                value = self.build(converter.tableClass, value)
        elif callable(converter):
            value = converter(value)

        setattr(dest, field, value)

    def build(self, cls, source):
        """Build an instance of *cls* (a BaseTable subclass) from *source*.

        *source* may be:
        - an instance of *cls*, which is returned unchanged (no callbacks run);
        - a mapping of field name -> value;
        - a non-string sequence of values, assumed to match the assignable
          fields in declaration order;
        - any other single value, which is retried as a 1-tuple.
        For FormatSwitchingBaseTable subclasses, the first sequence item (or
        the "Format" key) selects the format and its converters.

        Callback-table lookups (see BuildCallback):
        - CREATE_DEFAULT on (cls,);
        - BEFORE_BUILD and AFTER_BUILD on (cls,), or on (cls, Format) for
          format switchers.

        Raises:
            ValueError: for an unrecognized field name, an invalid enum name,
                or an unusable source for a format-switching table.
            AssertionError: for a sequence longer than the number of
                assignable fields, or a missing/invalid Format.
        """
        assert issubclass(cls, BaseTable)

        if isinstance(source, cls):
            return source

        callbackKey = (cls,)
        dest = self._callbackTable.get(
            (BuildCallback.CREATE_DEFAULT,) + callbackKey, lambda: cls()
        )()
        assert isinstance(dest, cls)

        convByName = _assignable(cls.convertersByName)
        skippedFields = set()

        # For format switchers we need to resolve converters based on format
        if issubclass(cls, FormatSwitchingBaseTable):
            source = _set_format(dest, source)

            convByName = _assignable(convByName[dest.Format])
            skippedFields.add("Format")
            callbackKey = (cls, dest.Format)

        # Convert sequence => mapping so before thunk only has to handle one format
        if _isNonStrSequence(source):
            # Sequence (typically list or tuple) assumed to match fields in declaration order
            assert len(source) <= len(
                convByName
            ), f"Sequence of {len(source)} too long for {cls}; expected <= {len(convByName)} values"
            source = dict(zip(convByName.keys(), source))

        dest, source = self._callbackTable.get(
            (BuildCallback.BEFORE_BUILD,) + callbackKey, lambda d, s: (d, s)
        )(dest, source)

        if not isinstance(source, collections.abc.Mapping):
            # Let's try as a 1-tuple. The recursive call already runs its own
            # callbacks, including AFTER_BUILD, so return its result directly.
            # Falling through would apply AFTER_BUILD a second time.
            return self.build(cls, (source,))

        for field, value in source.items():
            if field in skippedFields:
                continue
            converter = convByName.get(field)
            if converter is None:
                raise ValueError(
                    f"Unrecognized field {field} for {cls}; expected one of {sorted(convByName.keys())}"
                )
            self._convert(dest, field, converter, value)

        dest = self._callbackTable.get(
            (BuildCallback.AFTER_BUILD,) + callbackKey, lambda d: d
        )(dest)

        return dest
190
191
class TableUnbuilder:
    """Turn a BaseTable back into plain Python data (roughly TableBuilder's inverse).

    A callback table keyed on (table class,) or (table class, Format) may be
    given. The matching callable receives the unbuilt dict, and its return
    value becomes the result.
    """

    def __init__(self, callbackTable=None):
        if callbackTable is None:
            callbackTable = {}
        self._callbackTable = callbackTable

    def unbuild(self, table):
        """Return a dict describing the fields of *table*.

        - Format-switching tables get an int "Format" entry.
        - Fields with ComputedInt converters are skipped.
        - Tuple-class fields become tuples.
        - Enum fields become lower-case member names.
        - Struct fields are unbuilt recursively (as a list when repeated).
        - SimpleValue fields are copied as-is.

        Raises:
            NotImplementedError: for a field whose converter fits none of the
                cases above.
        """
        assert isinstance(table, BaseTable)

        source = {}

        callbackKey = (type(table),)
        if isinstance(table, FormatSwitchingBaseTable):
            source["Format"] = int(table.Format)
            callbackKey += (table.Format,)

        for converter in table.getConverters():
            if isinstance(converter, ComputedInt):
                continue
            value = getattr(table, converter.name)

            tupleClass = getattr(converter, "tupleClass", None)
            enumClass = getattr(converter, "enumClass", None)
            if tupleClass:
                source[converter.name] = tuple(value)
            elif enumClass:
                source[converter.name] = value.name.lower()
            elif isinstance(converter, Struct):
                if converter.repeat:
                    source[converter.name] = [self.unbuild(v) for v in value]
                else:
                    source[converter.name] = self.unbuild(value)
            elif isinstance(converter, SimpleValue):
                # "simple" values (e.g. int, float, str) need no further un-building
                source[converter.name] = value
            else:
                # Previously missing the f-prefix, so the placeholders were
                # printed literally.
                raise NotImplementedError(
                    f"Don't know how to unbuild {value!r} with {converter!r}"
                )

        # Let a per-class (or per-format) callback post-process the result.
        source = self._callbackTable.get(callbackKey, lambda s: s)(source)

        return source
235