1#
2# Copyright (C) 2020 Collabora, Ltd.
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22
23import sys
24import itertools
25from isa_parse import parse_instructions, opname_to_c, expand_states
26from mako.template import Template
27
# Parse the instruction descriptions named by the first command-line argument.
# NOTE(review): presumably this is a path to the ISA description understood by
# isa_parse.parse_instructions -- confirm against that module. Running the
# script without an argument raises IndexError here.
instructions = parse_instructions(sys.argv[1])
29
30# Constructs a reserved mask for a derived to cull impossible encodings
31
def reserved_mask(derived):
    """Return (pos, width, mask) for one derived field.

    `derived` is a ((pos, width), opts) pair. Bit i of the returned mask is
    set exactly when opts[i] is None, i.e. when value i of the field has no
    expression attached to it.
    """
    (pos, width), opts = derived

    mask = 0
    for index, opt in enumerate(opts):
        if opt is None:
            mask |= 1 << index

    return (pos, width, mask)
37
def reserved_masks(op):
    """Collect the non-empty reserved masks of an instruction.

    `op[2]` is the instruction description; each entry of its optional
    "derived" list is turned into a (pos, width, mask) triple, and triples
    whose mask is zero are dropped.
    """
    nonzero = []

    for entry in op[2].get("derived", []):
        candidate = reserved_mask(entry)

        if candidate[2] != 0:
            nonzero.append(candidate)

    return nonzero
41
42# To decode instructions, pattern match based on the rules:
43#
44# 1. Execution unit (FMA or ADD) must line up.
45# 2. All exact bits must match.
46# 3. No fields should be reserved in a legal encoding.
47# 4. Tiebreaker: Longer exact masks (greater unsigned bitwise inverses) win.
48#
49# To implement, filter the execution unit and check for exact bits in
# descending order of exact mask length.  Check for reserved fields per
# candidate and succeed as soon as a match is found.
53
def decode_op(instructions, is_fma):
    """Render the C decoder function for one execution unit.

    `instructions` maps a name to a tuple whose third element is the
    description dictionary; its "exact" entry is a (mask, bits) pair. The
    result is the C source of `bi_disasm_fma` (is_fma true) or
    `bi_disasm_add` (is_fma false): an if / else-if chain that tests every
    candidate in order and calls the matching per-instruction function,
    printing INSTR_INVALID_ENC when nothing matches.
    """
    # Names starting with '*' are selected when is_fma is true, all other
    # names otherwise
    candidates = [name for name in instructions.keys() if (name[0] == '*') == is_fma]

    # The unit's mask is 23 bits wide for FMA and 20 bits wide for ADD.
    # Ordering by the inverted exact mask, ascending, puts the numerically
    # largest exact mask first.
    all_ones = (1 << (23 if is_fma else 20)) - 1

    def inverted_exact_mask(name):
        return all_ones ^ instructions[name][2]["exact"][0]

    candidates = sorted(candidates, key = inverted_exact_mask)

    # Reduce each candidate to what the template consumes:
    # (C name, (exact mask, exact bits), reserved masks)
    mapped = []

    for name in candidates:
        entry = instructions[name]
        mapped.append((opname_to_c(name), entry[2]["exact"], reserved_masks(entry)))

    # Checks are emitted in sorted order, so the first hit wins
    template = """void
bi_disasm_${unit}(FILE *fp, unsigned bits, struct bifrost_regs *srcs, struct bifrost_regs *next_regs, unsigned staging_register, unsigned branch_offset, struct bi_constants *consts, bool last)
{
% for (i, (name, (emask, ebits), derived)) in enumerate(options):
% if len(derived) > 0:
    ${"else " if i > 0 else ""}if (unlikely(((bits & ${hex(emask)}) == ${hex(ebits)})
% for (pos, width, reserved) in derived:
        && !(${hex(reserved)} & (1 << _BITS(bits, ${pos}, ${width})))
% endfor
    ))
% else:
    ${"else " if i > 0 else ""}if (unlikely(((bits & ${hex(emask)}) == ${hex(ebits)})))
% endif
        bi_disasm_${name}(fp, bits, srcs, next_regs, staging_register, branch_offset, consts, last);
% endfor
    else
        fprintf(fp, "INSTR_INVALID_ENC ${unit} %X\\n", bits);
}"""

    unit = "fma" if is_fma else "add"
    return Template(template).render(options = mapped, unit = unit)
86
87# Decoding emits a series of function calls to e.g. `fma_fadd_v2f16`. We need to
88# emit functions to disassemble a single decoded instruction in a particular
89# state. Sync prototypes to avoid moves when calling.
90
# Wraps a generated body in a static C function named bi_disasm_<c_name>.
# The body is stripped of surrounding whitespace before insertion, so its
# first statement takes its indentation from the template itself.
disasm_op_template = Template("""static void
bi_disasm_${c_name}(FILE *fp, unsigned bits, struct bifrost_regs *srcs, struct bifrost_regs *next_regs, unsigned staging_register, unsigned branch_offset, struct bi_constants *consts, bool last)
{
    ${body.strip()}
}
""")
97
# Emits only a static C array of string literals called `field`, one quoted
# entry per element of `table`; the caller generates the lookup itself.
lut_template_only = Template("""    static const char *${field}[] = {
        ${", ".join(['"' + x + '"' for x in table])}
    };
""")
102
# Given a lookup table written logically, generate the table (named
# <field>_table) together with an accessor: a local `field` variable indexed
# by the `width` bits found at bit position `pos` of the instruction word.
lut_template = Template("""    static const char *${field}_table[] = {
        ${", ".join(['"' + x + '"' for x in table])}
    };

    const char *${field} = ${field}_table[_BITS(bits, ${pos}, ${width})];
""")
110
111# Helpers for decoding follow. pretty_mods applies dot syntax
112
def pretty_mods(opts, default):
    """Turn modifier values into the suffixes printed after a mnemonic.

    A value equal to `default` becomes the empty string. Any other value
    becomes '.' followed by the value, or '.reserved' when the value is
    falsy (for example None).
    """
    pretty = []

    for opt in opts:
        if opt == default:
            pretty.append('')
        else:
            pretty.append('.' + (opt or 'reserved'))

    return pretty
115
116# Recursively searches for the set of free variables required by an expression
117
def find_context_keys_expr(expr):
    """Return the set of free variable names used by a logic expression.

    A list is an operator followed by its operands, so only expr[1:] is
    searched recursively. A string starting with '#' is a literal and
    contributes nothing; any other string is itself a variable name.

    An empty list, or an operator without operands, yields the empty set.
    """
    if isinstance(expr, list):
        # Start from an empty set so that zero operands is not an error
        # (set.union called without arguments raises TypeError)
        return set().union(*(find_context_keys_expr(x) for x in expr[1:]))
    elif expr[0] == '#':
        return set()
    else:
        return {expr}
125
def find_context_keys(desc, test):
    """Return every free variable an instruction's expressions depend on.

    The variables come from `test` (skipped when it is empty) and from each
    non-None expression listed under the optional 'derived' entry of `desc`.
    """
    # Gather the expressions first, then search them in the same order
    exprs = [test] if len(test) > 0 else []

    for _, candidates in desc.get('derived', []):
        exprs.extend(val for val in candidates if val is not None)

    found = set()

    for expr in exprs:
        found.update(find_context_keys_expr(expr))

    return found
138
139# Compiles a logic expression to Python expression, ctx -> { T, F }
140
# Maps each operator of the logic language to the Python operator text that
# joins its compiled operands (padded with one space on either side).
EVALUATORS = {
    name: ' {} '.format(python_operator)
    for name, python_operator in (
        ('and', 'and'),
        ('or', 'or'),
        ('eq', '=='),
        ('neq', '!='),
    )
}
147
def compile_derived_inner(expr, keys):
    """Translate a logic expression into Python expression source text.

    The empty list compiles to 'True'; None and anything headed by 'alias'
    compile to 'False'. Other lists are an operator from EVALUATORS followed
    by operands, which are compiled recursively and joined inside
    parentheses. A '#'-prefixed string becomes a quoted literal, 'ordering'
    is passed through as a bare name, and any other string becomes an index
    into `ctx` at its position in `keys`.
    """
    if expr is None:
        return 'False'

    if expr == []:
        return 'True'

    head = expr[0]

    if head == 'alias':
        return 'False'

    if isinstance(expr, list):
        operands = [compile_derived_inner(operand, keys) for operand in expr[1:]]
        return '(' + EVALUATORS[head].join(operands) + ')'

    if head == '#':
        return "'{}'".format(expr[1:])

    if expr == 'ordering':
        return expr

    return "ctx[{}]".format(keys.index(expr))
162
def compile_derived(expr, keys):
    """Compile a logic expression into a callable taking (ctx, ordering).

    The expression is first translated to Python source text and then
    evaluated as the body of a lambda.
    """
    # NOTE(review): eval() runs text generated from the parsed instruction
    # descriptions; this is only appropriate while that input is trusted.
    source = 'lambda ctx, ordering: ' + compile_derived_inner(expr, keys)
    return eval(source)
165
166# Generate all possible combinations of values and evaluate the derived values
167# by bruteforce evaluation to generate a forward mapping (values -> deriveds)
168
def evaluate_forward_derived(vals, ctx, ordering):
    """Return the index of the first predicate in `vals` that holds.

    Each element of `vals` is called as expr(ctx, ordering); evaluation
    stops at the first truthy result. Returns None when none of them hold.
    """
    matches = (index for index, expr in enumerate(vals) if expr(ctx, ordering))
    return next(matches, None)
175
def evaluate_forward(keys, derivf, testf, ctx, ordering):
    """Evaluate every derived field for one modifier state.

    Returns a list holding, for each entry of `derivf`, the index of the
    first predicate satisfied by (ctx, ordering). Returns None when `testf`
    rejects the state or when some derived field has no satisfied predicate.
    `keys` is accepted but not used here.
    """
    if not testf(ctx, ordering):
        return None

    selected = []

    for candidates in derivf:
        index = evaluate_forward_derived(candidates, ctx, ordering)

        # One unencodable field makes the whole state unencodable
        if index is None:
            return None

        selected.append(index)

    return selected
191
def evaluate_forwards(keys, derivf, testf, mod_vals, ordered):
    """Build the forward mapping for every combination of modifier values.

    Returns one list per ordering ("lt" then "gt" when `ordered` is true,
    otherwise the single ordering None). Each inner list holds the
    evaluate_forward result for every element of the Cartesian product of
    `mod_vals`, in product order.
    """
    orderings = ["lt", "gt"] if ordered else [None]
    table = []

    for order in orderings:
        row = [evaluate_forward(keys, derivf, testf, combo, order)
               for combo in itertools.product(*mod_vals)]
        table.append(row)

    return table
195
196# Invert the forward mapping (values -> deriveds) of finite sets to produce a
197# backwards mapping (deriveds -> values), suitable for disassembly. This is
198# possible since the encoding is unambiguous, so this mapping is a bijection
199# (after reserved/impossible encodings)
200
def invert_lut(value_size, forward, derived, mod_map, keys, mod_vals):
    """Invert a forward mapping (state -> deriveds) into (deriveds -> state).

    `forward` lists, in itertools.product(*mod_vals) order, either None for
    a rejected state or the chosen index of every derived field. The indices
    are packed into a single integer, first derived field in the lowest bits
    and each field occupying its declared width, and the state tuple is
    stored at that position.

    Returns a list of 1 << value_size entries; positions no state maps to
    stay None. `mod_map` and `keys` are accepted but not used here.
    """
    backwards = [None] * (1 << value_size)
    widths = [width for ((_, width), _) in derived]

    for deriveds, ctx in zip(forward, itertools.product(*mod_vals)):
        # Skip reserved
        if deriveds is None:
            continue

        shift = 0
        param = 0
        for j, width in enumerate(widths):
            param += (deriveds[j] << shift)
            shift += width

        # NOTE(review): this tests whether the integer `param` is an element
        # of the list, which holds only None and state tuples, so it cannot
        # fail. If it was meant to detect two states sharing one encoding it
        # should inspect backwards[param] instead; left unchanged because
        # tightening it may abort generation on existing input -- confirm
        # against the instruction descriptions first.
        assert param not in backwards
        backwards[param] = ctx

    return backwards
218
219# Compute the value of all indirectly specified modifiers by using the
220# backwards mapping (deriveds -> values) as a run-time lookup table.
221
def build_lut(mnemonic, desc, test):
    """Generate the C lookup tables that recover indirect modifiers.

    Every combination of the modifier values referenced by `test` and by the
    derived expressions of `desc` is evaluated, the resulting mapping is
    inverted, and one string table is emitted for each referenced modifier
    whose bit position is None (an indirect modifier). When the expressions
    use 'ordering', two tables per modifier are emitted and the generated
    code chooses between them by comparing the first two sources.

    Returns the generated C declarations as a string; it is empty when there
    are no indirect modifiers and no ordering. `mnemonic` is accepted but
    not used here.
    """
    # Index the modifier descriptions by name
    mod_map = {}

    for ((name, pos, width), default, values) in desc.get('modifiers', []):
        mod_map[name] = (width, values, pos, default)

    derived = desc.get('derived', [])

    # Find the keys and impose an order. Sorting, rather than relying on set
    # iteration order (which varies between interpreter runs for strings),
    # keeps the generated source identical from one run to the next.
    key_set = find_context_keys(desc, test)
    ordered = 'ordering' in key_set
    key_set.discard('ordering')
    keys = sorted(key_set)

    # Evaluate the deriveds for every possible state, forming a (state -> deriveds) map
    testf = compile_derived(test, keys)
    derivf = [[compile_derived(expr, keys) for expr in v] for (_, v) in derived]
    mod_vals = [mod_map[k][1] for k in keys]
    forward = evaluate_forwards(keys, derivf, testf, mod_vals, ordered)

    # Now invert that map to get a (deriveds -> state) map
    value_size = sum(width for ((_, width), _) in derived)
    backwards = [invert_lut(value_size, f, derived, mod_map, keys, mod_vals) for f in forward]

    # The run-time table index packs the derived fields exactly as invert_lut
    # does; it does not depend on the key, so build it once
    idx_parts = []
    shift = 0

    for ((pos, width), _) in derived:
        idx_parts.append("(_BITS(bits, {}, {}) << {})".format(pos, width, shift))
        shift += width

    built_idx = " | ".join(idx_parts) if idx_parts else "0"

    # From that map, we can generate LUTs
    output = ""

    if ordered:
        output += "bool ordering = (_BITS(bits, {}, 3) > _BITS(bits, {}, 3));\n".format(desc["srcs"][0][0], desc["srcs"][1][0])

    for j, key in enumerate(keys):
        # Only generate tables for indirect specifiers
        if mod_map[key][2] is not None:
            continue

        default = mod_map[key][3]

        if ordered:
            # backwards[0] belongs to "lt" and backwards[1] to "gt", matching
            # the `ordering ? _1 : _0` selection emitted below
            for i, order in enumerate(backwards):
                options = [ctx[j] if ctx is not None and ctx[j] is not None else "reserved" for ctx in order]
                output += lut_template_only.render(field = key + "_" + str(i), table = pretty_mods(options, default))

            output += "    const char *{} = ordering ? {}_1[{}] : {}_0[{}];\n".format(key, key, built_idx, key, built_idx)
        else:
            options = [ctx[j] if ctx is not None and ctx[j] is not None else "reserved" for ctx in backwards[0]]
            output += lut_template_only.render(field = key + "_table", table = pretty_mods(options, default))
            output += "    const char *{} = {}_table[{}];\n".format(key, key, built_idx)

    return output
283
def disasm_mod(mod, skip_mods):
    """Return the C statement printing one modifier, or '' if it is skipped.

    `mod[0][0]` is the modifier's name, which is also the name of the C
    variable holding its printable text.
    """
    name = mod[0][0]

    if name not in skip_mods:
        return '    fputs({}, fp);\n'.format(name)

    return ''
289
def disasm_op(name, op):
    """Generate the C function that prints one instruction state.

    `op` is a (mnemonic, test, desc) triple. The generated body declares the
    modifier lookup tables, then prints the mnemonic, the modifiers that are
    not tied to a source, the destination, each source followed by its own
    modifiers, the immediates, and the staging register if one is used.

    Returns the rendered C function named after opname_to_c(name).
    """
    (mnemonic, test, desc) = op
    is_fma = mnemonic[0] == '*'
    modifiers = desc.get('modifiers', [])
    srcs = desc.get('srcs', [])

    # Nothing is skipped at present; the list is still handed to disasm_mod
    skip_mods = []

    # Modifiers may be either direct (pos is not None) or indirect (pos is
    # None). Indirect ones go through the tables built here; direct ones get
    # a plain bit lookup below.
    chunks = [build_lut(mnemonic, desc, test)]

    for ((mod, pos, width), default, opts) in modifiers:
        if pos is None:
            continue

        table = pretty_mods(opts, default)
        chunks.append(lut_template.render(field = mod, table = table, pos = pos, width = width) + "\n")

    # Mnemonic, followed by modifiers
    chunks.append('    fputs("{}", fp);\n'.format(mnemonic))

    for mod in modifiers:
        # A name ending in the index of an existing source is a per-source
        # modifier; those are printed after their source instead
        suffix = mod[0][0][-1]
        per_source = suffix in "0123" and int(suffix) < len(srcs)

        if not per_source:
            chunks.append(disasm_mod(mod, skip_mods))

    chunks.append('    fputs(" ", fp);\n')
    chunks.append('    bi_disasm_dest_{}(fp, next_regs, last);\n'.format('fma' if is_fma else 'add'))

    # Next up, each source. Source modifiers are inserted here

    for i, (pos, mask) in enumerate(srcs):
        chunks.append('    fputs(", ", fp);\n')
        chunks.append('    dump_src(fp, _BITS(bits, {}, 3), *srcs, consts, {});\n'.format(pos, "true" if is_fma else "false"))

        # Error check if needed: a mask other than 0xFF means some of the
        # eight source values are not allowed
        if mask != 0xFF:
            chunks.append('    if (!({} & (1 << _BITS(bits, {}, 3)))) fputs("(INVALID)", fp);\n'.format(hex(mask), pos))

        # Print modifiers suffixed with this src number (e.g. abs0 for src0)
        chunks.extend(disasm_mod(mod, skip_mods) for mod in modifiers if mod[0][0][-1] == str(i))

    # And each immediate
    for (imm, pos, width) in desc.get('immediates', []):
        chunks.append('    fprintf(fp, ", {}:%u", _BITS(bits, {}, {}));\n'.format(imm, pos, width))

    # Attach a staging register if one is used
    if desc.get('staging'):
        chunks.append('    fprintf(fp, ", @r%u", staging_register);\n')

    chunks.append('    fputs("\\n", fp);\n')

    return disasm_op_template.render(c_name = opname_to_c(name), body = "".join(chunks))
346
# Emit the generated C source on standard output. The order matters: the
# includes and the _BITS helper macro come first, then one static
# disassembly function per expanded state, and finally the two per-unit
# decoders that call those functions.
print('#include "util/macros.h"')
print('#include "disassemble.h"')

states = expand_states(instructions)
# _BITS extracts `width` bits of `bits` starting at bit `pos`
print('#define _BITS(bits, pos, width) (((bits) >> (pos)) & ((1 << (width)) - 1))')

for st in states:
    print(disasm_op(st, states[st]))

# FMA decoder first (is_fma = True), then the ADD decoder
print(decode_op(states, True))
print(decode_op(states, False))
358