1#
2# Copyright (C) 2014 Intel Corporation
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22#
23# Authors:
24#    Jason Ekstrand (jason@jlekstrand.net)
25
from __future__ import print_function
import ast
import itertools
import numbers
import re
import struct
import sys
import traceback

import mako.template

from nir_opcodes import opcodes
36
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit size encoded in a NIR type string.

   ``type_str`` must start with one of int, uint, bool or float (asserted).
   A trailing number such as the "32" in "float32" is returned as an int;
   an unsized type such as "float" yields 0.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return 0 if bits is None else int(bits)
47
48# Represents a set of variables, each with a unique id
49class VarSet(object):
50   def __init__(self):
51      self.names = {}
52      self.ids = itertools.count()
53      self.immutable = False;
54
55   def __getitem__(self, name):
56      if name not in self.names:
57         assert not self.immutable, "Unknown replacement variable: " + name
58         self.names[name] = self.ids.next()
59
60      return self.names[name]
61
62   def lock(self):
63      self.immutable = True
64
class Value(object):
   """Base class for the nodes of a search or replace expression tree.

   Subclasses (Constant, Variable, Expression) provide ``bit_size`` and the
   fields the C template below reads.  ``render()`` emits a static C
   definition named ``name`` of type ``c_type``; parents refer to it
   through ``c_ptr``.
   """

   @staticmethod
   def create(val, name_base, varset):
      """Convert a Python description of a value into a Value node.

      Tuples become Expressions and existing Expressions are returned
      unchanged.  Strings become Variables registered in ``varset``.
      bool/int/float literals become Constants named ``name_base``.

      Raises TypeError for any other kind of ``val`` (previously None was
      returned silently, which only failed later with an obscure error).
      """
      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, (str, type(u''))):
         # type(u'') is unicode on Python 2 and str on Python 3.
         return Variable(val, name_base, varset)
      elif isinstance(val, (numbers.Integral, float)):
         # numbers.Integral covers bool, int and (on Python 2) long.
         return Constant(val, name_base)
      else:
         raise TypeError(
            "Unsupported search/replace value: {0!r}".format(val))

   # The Constant branch calls val.__hex__() directly: the hex() builtin
   # only consults __hex__ on Python 2, so this keeps Python 3 working with
   # identical output.
   __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.__hex__()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def __init__(self, name, type_str):
      # ``type_str`` is one of "constant", "variable" or "expression" and
      # selects the nir_search_* C type and enum used when rendering.
      self.name = name
      self.type_str = type_str

   @property
   def type_enum(self):
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      return "nir_search_" + self.type_str

   @property
   def c_ptr(self):
      return "&{0}.value".format(self.name)

   def render(self):
      """Render this node as a static C definition."""
      return self.__template.render(val=self,
                                    Constant=Constant,
                                    Variable=Variable,
                                    Expression=Expression)
116
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal constant in a search or replace expression.

   ``val`` is either a Python bool/int/float, or a string of the form
   "<literal>[@<bits>]" (for example "1.0@64") whose literal part is parsed
   with ast.literal_eval.
   """

   def __init__(self, val, name):
      Value.__init__(self, name, "constant")

      # type(u'') is unicode on Python 2 and str on Python 3.
      if isinstance(val, (str, type(u''))):
         m = _constant_re.match(val)
         assert m, "Invalid constant: " + val
         self.value = ast.literal_eval(m.group('value'))
         self.bit_size = int(m.group('bits')) if m.group('bits') else 0
      else:
         self.value = val
         self.bit_size = 0

      # Boolean constants are always treated as 32-bit (nir_type_bool32).
      if isinstance(self.value, bool):
         assert self.bit_size == 0 or self.bit_size == 32
         self.bit_size = 32

   def __hex__(self):
      """Return the C initializer text for this constant's value.

      Booleans map to NIR_TRUE/NIR_FALSE, integers to their hex literal and
      floats to the hex literal of their IEEE-754 double bit pattern.
      """
      # bool must be tested first because bool is an Integral subclass.
      if isinstance(self.value, bool):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      elif isinstance(self.value, numbers.Integral):
         # numbers.Integral covers int on Python 3 and int/long on Python 2;
         # the bare name ``long`` does not exist on Python 3.
         return hex(self.value)
      elif isinstance(self.value, float):
         return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
      else:
         assert False

   def type(self):
      """Return the nir_alu_type name for this constant (None if unknown)."""
      if isinstance(self.value, bool):
         return "nir_type_bool32"
      elif isinstance(self.value, numbers.Integral):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"
152
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
   """A named placeholder in a search or replace expression.

   The name string has the form "[#]name[@[type][bits]][(cond)]":
   a leading "#" requires the matched value to be constant, the optional
   type/bits constrain what may match, and "(cond)" names an extra C
   condition.  Each distinct name receives an index from ``varset``.
   """

   # Maps a required type to its nir_alu_type name; unconstrained
   # variables (required_type None) map to nothing.
   _TYPE_ENUMS = {
      'bool': "nir_type_bool32",
      'int': "nir_type_int",
      'uint': "nir_type_int",
      'float': "nir_type_float",
   }

   def __init__(self, val, name, varset):
      Value.__init__(self, name, "variable")

      match = _var_name_re.match(val)
      assert match and match.group('name') is not None

      bits = match.group('bits')
      self.var_name = match.group('name')
      self.is_constant = match.group('const') is not None
      self.cond = match.group('cond')
      self.required_type = match.group('type')
      self.bit_size = int(bits) if bits else 0

      # Booleans are always 32-bit.
      if self.required_type == 'bool':
         assert self.bit_size in (0, 32)
         self.bit_size = 32

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_alu_type name required by this variable, or None."""
      return self._TYPE_ENUMS.get(self.required_type)
186
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An ALU operation node described by ``(opcode_string, src0, src1, ...)``.

   The opcode string has the form "[~]opcode[@bits][(cond)]"; a leading
   "~" marks the expression as inexact.  Sources are built recursively with
   Value.create and named "<name_base>_<index>".
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, name_base, "expression")
      assert isinstance(expr, tuple)

      opcode_str, src_descs = expr[0], expr[1:]
      match = _opcode_re.match(opcode_str)
      assert match and match.group('opcode') is not None

      bits = match.group('bits')
      self.opcode = match.group('opcode')
      self.bit_size = int(bits) if bits else 0
      self.inexact = match.group('inexact') is not None
      self.cond = match.group('cond')

      self.sources = []
      for idx, desc in enumerate(src_descs):
         child_name = "{0}_{1}".format(name_base, idx)
         self.sources.append(Value.create(desc, child_name, varset))

   def render(self):
      """Render every source first, then this expression, as C definitions."""
      rendered_sources = [child.render() for child in self.sources]
      return "\n".join(rendered_sources) + super(Expression, self).render()
208
class IntEquivalenceRelation(object):
   """A class representing an equivalence relation on integers.

   Each integer has a canonical form which is the maximum integer to which it
   is equivalent.  Two integers are equivalent precisely when they have the
   same canonical form.

   The convention of maximum is explicitly chosen to make using it in
   BitSizeValidator easier because it means that an actual bit_size (if any)
   will always be the canonical form.
   """
   def __init__(self):
      # Maps an integer to a strictly larger integer in the same class.
      # Following the chain ends at the canonical form.
      self._remap = {}

   def get_canonical(self, x):
      """Get the canonical integer corresponding to x."""
      # Every remap step strictly increases x, so this always terminates.
      while x in self._remap:
         x = self._remap[x]
      return x

   def add_equiv(self, a, b):
      """Add an equivalence and return the canonical form.

      The whole classes of ``a`` and ``b`` are merged by linking both
      canonical representatives to the larger one.  Remapping only ``a`` and
      ``b`` themselves would leave the other members of their classes
      behind.  For example, after add_equiv(-3, -2) and add_equiv(-3, -1),
      -2 would not be equivalent to -1.
      """
      canon_a = self.get_canonical(a)
      canon_b = self.get_canonical(b)
      c = max(canon_a, canon_b)

      # Link the representatives (merging the classes).  Also point a and b
      # straight at c to keep their lookup chains short.
      for x in (a, b, canon_a, canon_b):
         if x != c:
            assert x < c
            self._remap[x] = c

      return c
242
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   The basic operation of the validator is very similar to the bitsize_tree in
   nir_search only a little more subtle.  Instead of simply tracking bit
   sizes, it tracks "bit classes" where each class is represented by an
   integer.  A value of 0 means we don't know anything yet, positive values
   are actual bit-sizes, and negative values are used to track equivalence
   classes of sizes that must be the same but have yet to receive an actual
   size.  The first stage uses the bitsize_tree algorithm to assign bit
   classes to each variable.  If it ever comes across an inconsistency, it
   assert-fails.  Then the second stage uses that information to prove that
   the resulting expression can always validly be constructed.
   """

   def __init__(self, varset):
      # Number of negative (not-yet-sized) bit classes handed out so far.
      self._num_classes = 0
      # Bit class of each search variable, indexed by its VarSet id.
      # 0 means nothing is known about it yet.
      self._var_classes = [0] * len(varset.names)
      # Records which bit classes have been found to be equal.
      self._class_relation = IntEquivalenceRelation()

   def validate(self, search, replace):
      """Check that ``replace`` has well-defined bit sizes for ``search``.

      Stage one walks the search tree (up, then down) to give every search
      variable a bit class.  Stage two walks the replace tree (up, then down)
      and checks it against those classes and the search's destination
      class.  Any inconsistency raises AssertionError.
      """
      dst_class = self._propagate_bit_size_up(search)
      if dst_class == 0:
         # Destination size unknown: give it a fresh abstract class.
         dst_class = self._new_class()
      self._propagate_bit_class_down(search, dst_class)

      validate_dst_class = self._validate_bit_class_up(replace)
      # The replacement must produce the same class as the search (or
      # something not yet determined, which the down pass then pins).
      assert validate_dst_class == 0 or validate_dst_class == dst_class
      self._validate_bit_class_down(replace, dst_class)

   def _new_class(self):
      """Allocate and return a fresh negative (unsized) bit class."""
      self._num_classes += 1
      return -self._num_classes

   def _set_var_bit_class(self, var_id, bit_class):
      """Record that variable ``var_id`` must have class ``bit_class``.

      If the variable already has a class, the two are merged.  Asserts if
      the variable's canonical class is already a concrete size that
      differs from ``bit_class``.
      """
      assert bit_class != 0
      var_class = self._var_classes[var_id]
      if var_class == 0:
         self._var_classes[var_id] = bit_class
      else:
         canon_class = self._class_relation.get_canonical(var_class)
         # NOTE: a variable whose canonical class is already a concrete
         # (positive) size cannot be joined to a different class here, even
         # a still-unsized negative one; this assert fails instead.
         assert canon_class < 0 or canon_class == bit_class
         var_class = self._class_relation.add_equiv(var_class, bit_class)
         self._var_classes[var_id] = var_class

   def _get_var_bit_class(self, var_id):
      """Return the canonical bit class recorded for variable ``var_id``."""
      return self._class_relation.get_canonical(self._var_classes[var_id])

   def _propagate_bit_size_up(self, val):
      """Bottom-up pass over the search tree.

      Returns the bit size ``val`` is known to produce from its explicit
      annotations and opcode types, or 0 if it is not yet known.  As a side
      effect, stores on every Expression ``common_size``: the size shared by
      its unsized sources/destination (0 if unknown).
      """
      if isinstance(val, (Constant, Variable)):
         return val.bit_size

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_size = 0
         for i in range(nir_op.num_inputs):
            src_bits = self._propagate_bit_size_up(val.sources[i])
            if src_bits == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               # Sized input: the source must match the opcode's type exactly.
               assert src_bits == src_type_bits
            else:
               # Unsized input: all unsized inputs must agree with each other.
               assert val.common_size == 0 or src_bits == val.common_size
               val.common_size = src_bits

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            # Unsized output shares the unsized inputs' size; an explicit
            # "@bits" on the opcode can supply it when the inputs did not.
            if val.common_size != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_size
            else:
               val.common_size = val.bit_size
            return val.common_size

   def _propagate_bit_class_down(self, val, bit_class):
      """Top-down pass over the search tree.

      Pushes ``bit_class`` (a concrete size or a negative class) from the
      parent into ``val`` and records the resulting classes for search
      variables.  Relies on ``common_size`` set by _propagate_bit_size_up.
      """
      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class
         self._set_var_bit_class(val.index, bit_class)

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == 0 or bit_class == dst_type_bits
         else:
            # Unsized output: the parent's class becomes this expression's
            # common (unsized) class.
            assert val.common_size == 0 or val.common_size == bit_class
            val.common_size = bit_class

         if val.common_size:
            common_class = val.common_size
         elif nir_op.num_inputs:
            # If we got here then we have no idea what the actual size is.
            # Instead, we use a generic class
            common_class = self._new_class()
         # When common_size is 0 and there are no inputs, common_class stays
         # unbound, but the loop below does not run in that case.

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._propagate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._propagate_bit_class_down(val.sources[i], common_class)

   def _validate_bit_class_up(self, val):
      """Bottom-up pass over the replace tree.

      Returns the bit class ``val`` is known to produce (0 if unknown),
      using the variable classes recorded from the search.  As a side
      effect, stores ``common_class`` on every Expression.
      """
      if isinstance(val, Constant):
         return val.bit_size

      elif isinstance(val, Variable):
         var_class = self._get_var_bit_class(val.index)
         # By the time we get to validation, every variable should have a class
         assert var_class != 0

         # If we have an explicit size provided by the user, the variable
         # *must* exactly match the search.  It cannot be implicitly sized
         # because otherwise we could end up with a conflict at runtime.
         assert val.bit_size == 0 or val.bit_size == var_class

         return var_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_class = 0
         for i in range(nir_op.num_inputs):
            src_class = self._validate_bit_class_up(val.sources[i])
            if src_class == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               assert src_class == src_type_bits
            else:
               assert val.common_class == 0 or src_class == val.common_class
               val.common_class = src_class

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            if val.common_class != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_class
            else:
               val.common_class = val.bit_size
            return val.common_class

   def _validate_bit_class_down(self, val, bit_class):
      """Top-down pass over the replace tree.

      Checks that every node can be given the non-zero class ``bit_class``
      handed down from its parent.  Relies on ``common_class`` set by
      _validate_bit_class_up.
      """
      # At this point, everything *must* have a bit class.  Otherwise, we have
      # a value we don't know how to define.
      assert bit_class != 0

      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == dst_type_bits
         else:
            assert val.common_class == 0 or val.common_class == bit_class
            val.common_class = bit_class

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._validate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._validate_bit_class_down(val.sources[i], val.common_class)
463
# Source of unique ids for SearchAndReplace objects.  Each id becomes part of
# the generated C names ("search<id>", "replace<id>").
_optimization_ids = itertools.count()

# Every distinct condition expression seen so far, shared by all passes
# created in this process.  Index 0 is always the unconditional 'true'.  The
# generated pass stores each entry in condition_flags[], and transforms
# refer to their condition by index.
condition_list = ['true']

class SearchAndReplace(object):
   """A single transformation ``(search, replace[, condition])``.

   ``search`` is an Expression or a tuple describing one.  ``replace`` is a
   Value or anything Value.create() accepts.  ``condition`` is a C
   expression string (default 'true') that gates the transformation in the
   generated pass.  Bit sizes are checked by BitSizeValidator on
   construction.
   """
   def __init__(self, transform):
      # next() works on Python 2 and 3; iterator.next() is Python 2 only.
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      if len(transform) > 2:
         self.condition = transform[2]
      else:
         self.condition = 'true'

      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         # NOTE(review): a prebuilt search Expression registered its
         # variables in some other VarSet, so ``varset`` stays empty here and
         # BitSizeValidator has no slots for them.  This path looks usable
         # only for variable-free searches; confirm.
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset)

      # From here on, the replacement may only use variables that occur in
      # the search expression.
      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

      BitSizeValidator(varset).validate(self.search, self.replace)
497
# Mako template for the generated C pass.  For each search opcode it emits
# the static search/replace trees and a transform table.  It then emits a
# pass that visits blocks and instructions in reverse and, for every ALU
# instruction with an SSA destination, tries the transforms registered for
# its opcode whose condition flag is set.  The template iterates with
# xform_dict.items() rather than iteritems() so it renders on Python 2 and 3.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
601
class AlgebraicPass(object):
   """Collect transformations and render them as a C NIR optimization pass.

   ``transforms`` may contain SearchAndReplace objects or raw transform
   tuples.  Every tuple that fails to parse is reported on stderr.  Once all
   transforms have been tried, the process exits with status 1, so every
   bad transform is listed in a single run.
   """
   def __init__(self, pass_name, transforms):
      # Maps a search-root opcode name to the transforms rooted at it.
      self.xform_dict = {}
      self.pass_name = pass_name

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except Exception:
               # Report and keep going so every bad transform is listed.
               # Unlike a bare except, this lets KeyboardInterrupt and
               # SystemExit propagate.
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         # Group by the opcode of the search root so the generated switch
         # only tries the transforms that can match that opcode.
         self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

      if error:
         sys.exit(1)

   def render(self):
      """Return the C source of the generated pass."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xform_dict=self.xform_dict,
                                             condition_list=condition_list)
633