#
# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Authors:
#    Jason Ekstrand (jason@jlekstrand.net)

from __future__ import print_function
import ast
from collections import defaultdict
import itertools
import struct
import sys
import mako.template
import re
import traceback

from nir_opcodes import opcodes, type_sizes

# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8

# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}

def get_c_opcode(op):
   if op in conv_opcode_types:
      return 'nir_search_op_' + op
   else:
      return 'nir_op_' + op
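
# For example (informal, not executed here): get_c_opcode('i2f') yields
# 'nir_search_op_i2f' because i2f is one of the generic conversion opcodes
# above, while get_c_opcode('iadd') yields the ordinary 'nir_op_iadd'.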


if sys.version_info < (3, 0):
    integer_types = (int, long)
    string_type = unicode

else:
    integer_types = (int, )
    string_type = str

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   m = _type_re.match(type_str)
   assert m.group('type')

   if m.group('bits') is None:
      return 0
   else:
      return int(m.group('bits'))
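
# For example, type_bits('float64') returns 64, while type_bits('int') returns
# 0 because that type string carries no explicit bit size.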

# Represents a set of variables, each with a unique id
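# (for example, a fresh VarSet hands out 0 for the first name looked up, 1 for
# the next distinct name, and so on; after lock(), looking up a name that was
# never seen raises an assertion error)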
class VarSet(object):
   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      if name not in self.names:
         assert not self.immutable, "Unknown replacement variable: " + name
         self.names[name] = next(self.ids)

      return self.names[name]

   def lock(self):
      self.immutable = True

class Value(object):
   @staticmethod
   def create(val, name_base, varset):
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)

   def __init__(self, val, name, type_str):
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         # case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)

_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
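
# Constants come in either as plain Python literals (1, 0x7f, 1.5, True) or as
# strings with an optional bit size suffix; for example the string
# "0x3f800000@32" parses to the value 0x3f800000 with an explicit 32-bit size.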

class Constant(Value):
   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, (str)):
         m = _constant_re.match(val)
         self.value = ast.literal_eval(m.group('value'))
         self._bit_size = int(m.group('bits')) if m.group('bits') else None
      else:
         self.value = val
         self._bit_size = None

      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      if isinstance(self.value, (bool)):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      if isinstance(self.value, integer_types):
         return hex(self.value)
      elif isinstance(self.value, float):
         i = struct.unpack('Q', struct.pack('d', self.value))[0]
         h = hex(i)

         # On Python 2 this 'L' suffix is added automatically, but not on Python 3.
         # Adding it explicitly makes the generated file identical, regardless
         # of the Python version running this script.
         if h[-1] != 'L' and i > sys.maxsize:
            h += 'L'

         return h
      else:
         assert False

   def type(self):
      if isinstance(self.value, (bool)):
         return "nir_type_bool"
      elif isinstance(self.value, integer_types):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      would break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.value == other.value

# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?"
                          r"$")
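
# As an illustration, the variable string "#a@32(is_used_once).x" would parse
# as: constant-only ('#'), name 'a', a required 32-bit size, a condition named
# is_used_once (the name here is only an example; conditions must correspond to
# helpers declared in nir_search_helpers.h), and a swizzle selecting .x.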

class Variable(Value):
   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None
      self.cond = m.group('cond')
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      would break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.index == other.index

   def swizzle(self):
      if self.swiz is not None:
         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")
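
# As an illustration, the opcode string "~fadd@32(is_used_once)" would parse
# as: the fadd opcode marked inexact ('~'), a required 32-bit size, and a
# condition (again, the condition name is only an example and must map to a
# helper in nir_search_helpers.h).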

class Expression(Value):
   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's a notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and self.cond.find("many-comm-expr") >= 0:
         # Split the condition into a comma-separated list.  Remove
         # "many-comm-expr".  If there is anything left, put it back together.
         c = self.cond[1:-1].split(",")
         c.remove("many-comm-expr")

         self.cond = "({})".format(",".join(c)) if c else None
         self.many_commutative_expressions = True

      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      would break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.

      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      srcs = "\n".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)

class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value, which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, and we prevent specializing the search bitsize with
   anything.  The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32, the bitsize of the replace expression.
   It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a == b, 1 if
        b <= a (a can be specialized to b), and None if they are not comparable
        (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of a and b to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)

   def validate(self, search, replace):
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
               search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)

_optimization_ids = itertools.count()

condition_list = ['true']

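# Each transform handed to SearchAndReplace is a (search, replace) or
# (search, replace, condition) tuple.  For example, something along the lines
# of (('iadd', 'a', 0), 'a') rewrites an add of zero to its other operand; the
# optional condition is a C boolean expression that the generated pass
# evaluates against the shader's compiler options/info (see the real transform
# lists in files such as nir_opt_algebraic.py for concrete condition strings).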
class SearchAndReplace(object):
   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      if len(transform) > 2:
         self.condition = transform[2]
      else:
         self.condition = 'true'

      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset)

      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

      BitSizeValidator(varset).validate(self.search, self.replace)

class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFAs and DFAs, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA-to-DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function, which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.
   """
   def __init__(self, transforms):
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))

   class IndexMap(object):
      """An indexed list of objects, where one can either look up an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
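      # For example: m = TreeAutomaton.IndexMap(); m.add('a') returns 0,
      # m.add('b') returns 1, a second m.add('a') returns 0 again, and
      # iterating over m yields 'a' then 'b' in insertion order.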
      def __init__(self, iterable=()):
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         return self.map[obj]

      def add(self, obj):
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed up
         # filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)

   def _compute_items(self):
      """Build a set of all possible items, deduplicating them."""
      # This is a map from (opcode, sources) to item.
      self.items = {}

      # The set of all opcodes used by the patterns. Used later to avoid
      # building and emitting all the tables for opcodes that aren't used.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties
         item = self.items.setdefault((opcode, children),
                                      self.Item(opcode, children))
         if commutative:
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2b1, are tricky.  Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2b1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2b, or we
               # construct it to match the search opcode, in which case we
               # need to map f2b1 to f2b when constructing the automaton. Here
               # we do the latter.
               opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for i, child in enumerate(children):
               child.parent_ops.add(opcode)
            return item

      for i, pattern in enumerate(self.patterns):
         process_subpattern(pattern, i)

   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based on Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Item's that match a
      given instruction, and builds up the transition table between states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # List of pattern matches for each state index.
      self.state_patterns = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with an index of at least worklist_index is
      # part of the worklist of newly created states. There is also a worklist
      # of newly filtered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(lambda: 0)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]

            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = list(sorted(p for item in state for p in item.patterns))
            assert len(self.state_patterns) == self.worklist_index
            self.state_patterns.append(patterns)

            # Calculate the filter table for this state, and update the
            # filtered worklists.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               filtered = frozenset(item for item in state if \
                  op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const, self.wildcard)))
      process_new_states()

      while len(new_opcodes) > 0:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            if op in conv_opcode_types:
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                  itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()

_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {} %>
% for xform in xforms:
   ${xform.search.render(cache)}
   ${xform.replace.render(cache)}
% endfor

% for state_id, state_xforms in enumerate(automaton.state_patterns):
% if state_xforms: # avoid emitting a 0-length array for MSVC
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
% for i in state_xforms:
  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
% endif
% endfor

static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
      .filter = (uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

const struct transform *${pass_name}_transforms[] = {
% for i in range(len(automaton.state_patterns)):
   % if automaton.state_patterns[i]:
   ${pass_name}_state${i}_xforms,
   % else:
   NULL,
   % endif
% endfor
};

const uint16_t ${pass_name}_transform_counts[] = {
% for i in range(len(automaton.state_patterns)):
   % if automaton.state_patterns[i]:
   (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
   % else:
   0,
   % endif
% endfor
};

bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl) {
         progress |= nir_algebraic_impl(function->impl, condition_flags,
                                        ${pass_name}_transforms,
                                        ${pass_name}_transform_counts,
                                        ${pass_name}_table);
      }
   }

   return progress;
}
""")


class AlgebraicPass(object):
   def __init__(self, pass_name, transforms):
      self.xforms = []
      self.opcode_xforms = defaultdict(lambda : [])
      self.pass_name = pass_name

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except:
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               print("Transform expected to have too many commutative " \
                     "expressions but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}).  Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      if error:
         sys.exit(1)


   def render(self):
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools)

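# Typical use, as a minimal sketch only (real callers are the generator
# scripts such as nir_opt_algebraic.py; the transform below is a stand-in
# example, not part of any real pass):
#
#   a = 'a'
#   optimizations = [
#      (('iadd', a, 0), a),
#   ]
#   print(AlgebraicPass("nir_opt_algebraic_example", optimizations).render())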