1 /*
2  * Copyright © 2014 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  * Authors:
24  *    Jason Ekstrand (jason@jlekstrand.net)
25  *
26  */
27 
28 #include <inttypes.h>
29 #include "nir_search.h"
30 
31 struct match_state {
32    bool inexact_match;
33    bool has_exact_alu;
34    unsigned variables_seen;
35    nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
36 };
37 
38 static bool
39 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
40                  unsigned num_components, const uint8_t *swizzle,
41                  struct match_state *state);
42 
43 static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
44 
45 /**
46  * Check if a source produces a value of the given type.
47  *
48  * Used for satisfying 'a@type' constraints.
49  */
50 static bool
src_is_type(nir_src src,nir_alu_type type)51 src_is_type(nir_src src, nir_alu_type type)
52 {
53    assert(type != nir_type_invalid);
54 
55    if (!src.is_ssa)
56       return false;
57 
58    /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
59    if (nir_alu_type_get_base_type(type) == nir_type_bool)
60       type = nir_type_bool;
61 
62    if (src.ssa->parent_instr->type == nir_instr_type_alu) {
63       nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
64       nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
65 
66       if (type == nir_type_bool) {
67          switch (src_alu->op) {
68          case nir_op_iand:
69          case nir_op_ior:
70          case nir_op_ixor:
71             return src_is_type(src_alu->src[0].src, nir_type_bool) &&
72                    src_is_type(src_alu->src[1].src, nir_type_bool);
73          case nir_op_inot:
74             return src_is_type(src_alu->src[0].src, nir_type_bool);
75          default:
76             break;
77          }
78       }
79 
80       return nir_alu_type_get_base_type(output_type) == type;
81    } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
82       nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
83 
84       if (type == nir_type_bool) {
85          return intr->intrinsic == nir_intrinsic_load_front_face ||
86                 intr->intrinsic == nir_intrinsic_load_helper_invocation;
87       }
88    }
89 
90    /* don't know */
91    return false;
92 }
93 
94 static bool
match_value(const nir_search_value * value,nir_alu_instr * instr,unsigned src,unsigned num_components,const uint8_t * swizzle,struct match_state * state)95 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
96             unsigned num_components, const uint8_t *swizzle,
97             struct match_state *state)
98 {
99    uint8_t new_swizzle[4];
100 
101    /* Searching only works on SSA values because, if it's not SSA, we can't
102     * know if the value changed between one instance of that value in the
103     * expression and another.  Also, the replace operation will place reads of
104     * that value right before the last instruction in the expression we're
105     * replacing so those reads will happen after the original reads and may
106     * not be valid if they're register reads.
107     */
108    if (!instr->src[src].src.is_ssa)
109       return false;
110 
111    /* If the source is an explicitly sized source, then we need to reset
112     * both the number of components and the swizzle.
113     */
114    if (nir_op_infos[instr->op].input_sizes[src] != 0) {
115       num_components = nir_op_infos[instr->op].input_sizes[src];
116       swizzle = identity_swizzle;
117    }
118 
119    for (unsigned i = 0; i < num_components; ++i)
120       new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
121 
122    /* If the value has a specific bit size and it doesn't match, bail */
123    if (value->bit_size &&
124        nir_src_bit_size(instr->src[src].src) != value->bit_size)
125       return false;
126 
127    switch (value->type) {
128    case nir_search_value_expression:
129       if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
130          return false;
131 
132       return match_expression(nir_search_value_as_expression(value),
133                               nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
134                               num_components, new_swizzle, state);
135 
136    case nir_search_value_variable: {
137       nir_search_variable *var = nir_search_value_as_variable(value);
138       assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
139 
140       if (state->variables_seen & (1 << var->variable)) {
141          if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
142             return false;
143 
144          assert(!instr->src[src].abs && !instr->src[src].negate);
145 
146          for (unsigned i = 0; i < num_components; ++i) {
147             if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
148                return false;
149          }
150 
151          return true;
152       } else {
153          if (var->is_constant &&
154              instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
155             return false;
156 
157          if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
158             return false;
159 
160          if (var->type != nir_type_invalid &&
161              !src_is_type(instr->src[src].src, var->type))
162             return false;
163 
164          state->variables_seen |= (1 << var->variable);
165          state->variables[var->variable].src = instr->src[src].src;
166          state->variables[var->variable].abs = false;
167          state->variables[var->variable].negate = false;
168 
169          for (unsigned i = 0; i < 4; ++i) {
170             if (i < num_components)
171                state->variables[var->variable].swizzle[i] = new_swizzle[i];
172             else
173                state->variables[var->variable].swizzle[i] = 0;
174          }
175 
176          return true;
177       }
178    }
179 
180    case nir_search_value_constant: {
181       nir_search_constant *const_val = nir_search_value_as_constant(value);
182 
183       if (!instr->src[src].src.is_ssa)
184          return false;
185 
186       if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
187          return false;
188 
189       nir_load_const_instr *load =
190          nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
191 
192       switch (const_val->type) {
193       case nir_type_float:
194          for (unsigned i = 0; i < num_components; ++i) {
195             double val;
196             switch (load->def.bit_size) {
197             case 32:
198                val = load->value.f32[new_swizzle[i]];
199                break;
200             case 64:
201                val = load->value.f64[new_swizzle[i]];
202                break;
203             default:
204                unreachable("unknown bit size");
205             }
206 
207             if (val != const_val->data.d)
208                return false;
209          }
210          return true;
211 
212       case nir_type_int:
213       case nir_type_uint:
214       case nir_type_bool32:
215          switch (load->def.bit_size) {
216          case 32:
217             for (unsigned i = 0; i < num_components; ++i) {
218                if (load->value.u32[new_swizzle[i]] !=
219                    (uint32_t)const_val->data.u)
220                   return false;
221             }
222             return true;
223 
224          case 64:
225             for (unsigned i = 0; i < num_components; ++i) {
226                if (load->value.u64[new_swizzle[i]] != const_val->data.u)
227                   return false;
228             }
229             return true;
230 
231          default:
232             unreachable("unknown bit size");
233          }
234 
235       default:
236          unreachable("Invalid alu source type");
237       }
238    }
239 
240    default:
241       unreachable("Invalid search value type");
242    }
243 }
244 
245 static bool
match_expression(const nir_search_expression * expr,nir_alu_instr * instr,unsigned num_components,const uint8_t * swizzle,struct match_state * state)246 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
247                  unsigned num_components, const uint8_t *swizzle,
248                  struct match_state *state)
249 {
250    if (expr->cond && !expr->cond(instr))
251       return false;
252 
253    if (instr->op != expr->opcode)
254       return false;
255 
256    assert(instr->dest.dest.is_ssa);
257 
258    if (expr->value.bit_size &&
259        instr->dest.dest.ssa.bit_size != expr->value.bit_size)
260       return false;
261 
262    state->inexact_match = expr->inexact || state->inexact_match;
263    state->has_exact_alu = instr->exact || state->has_exact_alu;
264    if (state->inexact_match && state->has_exact_alu)
265       return false;
266 
267    assert(!instr->dest.saturate);
268    assert(nir_op_infos[instr->op].num_inputs > 0);
269 
270    /* If we have an explicitly sized destination, we can only handle the
271     * identity swizzle.  While dot(vec3(a, b, c).zxy) is a valid
272     * expression, we don't have the information right now to propagate that
273     * swizzle through.  We can only properly propagate swizzles if the
274     * instruction is vectorized.
275     */
276    if (nir_op_infos[instr->op].output_size != 0) {
277       for (unsigned i = 0; i < num_components; i++) {
278          if (swizzle[i] != i)
279             return false;
280       }
281    }
282 
283    /* Stash off the current variables_seen bitmask.  This way we can
284     * restore it prior to matching in the commutative case below.
285     */
286    unsigned variables_seen_stash = state->variables_seen;
287 
288    bool matched = true;
289    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
290       if (!match_value(expr->srcs[i], instr, i, num_components,
291                        swizzle, state)) {
292          matched = false;
293          break;
294       }
295    }
296 
297    if (matched)
298       return true;
299 
300    if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
301       assert(nir_op_infos[instr->op].num_inputs == 2);
302 
303       /* Restore the variables_seen bitmask.  If we don't do this, then we
304        * could end up with an erroneous failure due to variables found in the
305        * first match attempt above not matching those in the second.
306        */
307       state->variables_seen = variables_seen_stash;
308 
309       if (!match_value(expr->srcs[0], instr, 1, num_components,
310                        swizzle, state))
311          return false;
312 
313       return match_value(expr->srcs[1], instr, 0, num_components,
314                          swizzle, state);
315    } else {
316       return false;
317    }
318 }
319 
/* Node of the tree used to work out bit sizes for a replacement value.
 * A node has one child per source; leaves have num_srcs == 0.
 */
struct bitsize_tree {
   /* Number of valid entries in srcs[], is_src_sized[] and src_size[]. */
   unsigned num_srcs;
   struct bitsize_tree *srcs[4];

   /* Bit size shared by every source/destination that has no explicit
    * size of its own; 0 while still unknown.
    */
   unsigned common_size;
   bool is_src_sized[4];
   bool is_dest_sized;

   /* Resolved bit sizes; 0 while still unknown. */
   unsigned dest_size;
   unsigned src_size[4];
};

typedef struct bitsize_tree bitsize_tree;
331 
332 static bitsize_tree *
build_bitsize_tree(void * mem_ctx,struct match_state * state,const nir_search_value * value)333 build_bitsize_tree(void *mem_ctx, struct match_state *state,
334                    const nir_search_value *value)
335 {
336    bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
337 
338    switch (value->type) {
339    case nir_search_value_expression: {
340       nir_search_expression *expr = nir_search_value_as_expression(value);
341       nir_op_info info = nir_op_infos[expr->opcode];
342       tree->num_srcs = info.num_inputs;
343       tree->common_size = 0;
344       for (unsigned i = 0; i < info.num_inputs; i++) {
345          tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
346          if (tree->is_src_sized[i])
347             tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
348          tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
349       }
350       tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
351       if (tree->is_dest_sized)
352          tree->dest_size = nir_alu_type_get_type_size(info.output_type);
353       break;
354    }
355 
356    case nir_search_value_variable: {
357       nir_search_variable *var = nir_search_value_as_variable(value);
358       tree->num_srcs = 0;
359       tree->is_dest_sized = true;
360       tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
361       break;
362    }
363 
364    case nir_search_value_constant: {
365       tree->num_srcs = 0;
366       tree->is_dest_sized = false;
367       tree->common_size = 0;
368       break;
369    }
370    }
371 
372    if (value->bit_size) {
373       assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
374       tree->common_size = value->bit_size;
375    }
376 
377    return tree;
378 }
379 
380 static unsigned
bitsize_tree_filter_up(bitsize_tree * tree)381 bitsize_tree_filter_up(bitsize_tree *tree)
382 {
383    for (unsigned i = 0; i < tree->num_srcs; i++) {
384       unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
385       if (src_size == 0)
386          continue;
387 
388       if (tree->is_src_sized[i]) {
389          assert(src_size == tree->src_size[i]);
390       } else if (tree->common_size != 0) {
391          assert(src_size == tree->common_size);
392          tree->src_size[i] = src_size;
393       } else {
394          tree->common_size = src_size;
395          tree->src_size[i] = src_size;
396       }
397    }
398 
399    if (tree->num_srcs && tree->common_size) {
400       if (tree->dest_size == 0)
401          tree->dest_size = tree->common_size;
402       else if (!tree->is_dest_sized)
403          assert(tree->dest_size == tree->common_size);
404 
405       for (unsigned i = 0; i < tree->num_srcs; i++) {
406          if (!tree->src_size[i])
407             tree->src_size[i] = tree->common_size;
408       }
409    }
410 
411    return tree->dest_size;
412 }
413 
414 static void
bitsize_tree_filter_down(bitsize_tree * tree,unsigned size)415 bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
416 {
417    if (tree->dest_size)
418       assert(tree->dest_size == size);
419    else
420       tree->dest_size = size;
421 
422    if (!tree->is_dest_sized) {
423       if (tree->common_size)
424          assert(tree->common_size == size);
425       else
426          tree->common_size = size;
427    }
428 
429    for (unsigned i = 0; i < tree->num_srcs; i++) {
430       if (!tree->src_size[i]) {
431          assert(tree->common_size);
432          tree->src_size[i] = tree->common_size;
433       }
434       bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
435    }
436 }
437 
438 static nir_alu_src
construct_value(const nir_search_value * value,unsigned num_components,bitsize_tree * bitsize,struct match_state * state,nir_instr * instr,void * mem_ctx)439 construct_value(const nir_search_value *value,
440                 unsigned num_components, bitsize_tree *bitsize,
441                 struct match_state *state,
442                 nir_instr *instr, void *mem_ctx)
443 {
444    switch (value->type) {
445    case nir_search_value_expression: {
446       const nir_search_expression *expr = nir_search_value_as_expression(value);
447 
448       if (nir_op_infos[expr->opcode].output_size != 0)
449          num_components = nir_op_infos[expr->opcode].output_size;
450 
451       nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
452       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
453                         bitsize->dest_size, NULL);
454       alu->dest.write_mask = (1 << num_components) - 1;
455       alu->dest.saturate = false;
456 
457       /* We have no way of knowing what values in a given search expression
458        * map to a particular replacement value.  Therefore, if the
459        * expression we are replacing has any exact values, the entire
460        * replacement should be exact.
461        */
462       alu->exact = state->has_exact_alu;
463 
464       for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
465          /* If the source is an explicitly sized source, then we need to reset
466           * the number of components to match.
467           */
468          if (nir_op_infos[alu->op].input_sizes[i] != 0)
469             num_components = nir_op_infos[alu->op].input_sizes[i];
470 
471          alu->src[i] = construct_value(expr->srcs[i],
472                                        num_components, bitsize->srcs[i],
473                                        state, instr, mem_ctx);
474       }
475 
476       nir_instr_insert_before(instr, &alu->instr);
477 
478       nir_alu_src val;
479       val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
480       val.negate = false;
481       val.abs = false,
482       memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
483 
484       return val;
485    }
486 
487    case nir_search_value_variable: {
488       const nir_search_variable *var = nir_search_value_as_variable(value);
489       assert(state->variables_seen & (1 << var->variable));
490 
491       nir_alu_src val = { NIR_SRC_INIT };
492       nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
493 
494       assert(!var->is_constant);
495 
496       return val;
497    }
498 
499    case nir_search_value_constant: {
500       const nir_search_constant *c = nir_search_value_as_constant(value);
501       nir_load_const_instr *load =
502          nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size);
503 
504       switch (c->type) {
505       case nir_type_float:
506          load->def.name = ralloc_asprintf(load, "%f", c->data.d);
507          switch (bitsize->dest_size) {
508          case 32:
509             load->value.f32[0] = c->data.d;
510             break;
511          case 64:
512             load->value.f64[0] = c->data.d;
513             break;
514          default:
515             unreachable("unknown bit size");
516          }
517          break;
518 
519       case nir_type_int:
520          load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i);
521          switch (bitsize->dest_size) {
522          case 32:
523             load->value.i32[0] = c->data.i;
524             break;
525          case 64:
526             load->value.i64[0] = c->data.i;
527             break;
528          default:
529             unreachable("unknown bit size");
530          }
531          break;
532 
533       case nir_type_uint:
534          load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u);
535          switch (bitsize->dest_size) {
536          case 32:
537             load->value.u32[0] = c->data.u;
538             break;
539          case 64:
540             load->value.u64[0] = c->data.u;
541             break;
542          default:
543             unreachable("unknown bit size");
544          }
545          break;
546 
547       case nir_type_bool32:
548          load->value.u32[0] = c->data.u;
549          break;
550       default:
551          unreachable("Invalid alu source type");
552       }
553 
554       nir_instr_insert_before(instr, &load->instr);
555 
556       nir_alu_src val;
557       val.src = nir_src_for_ssa(&load->def);
558       val.negate = false;
559       val.abs = false,
560       memset(val.swizzle, 0, sizeof val.swizzle);
561 
562       return val;
563    }
564 
565    default:
566       unreachable("Invalid search value type");
567    }
568 }
569 
570 nir_alu_instr *
nir_replace_instr(nir_alu_instr * instr,const nir_search_expression * search,const nir_search_value * replace,void * mem_ctx)571 nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
572                   const nir_search_value *replace, void *mem_ctx)
573 {
574    uint8_t swizzle[4] = { 0, 0, 0, 0 };
575 
576    for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
577       swizzle[i] = i;
578 
579    assert(instr->dest.dest.is_ssa);
580 
581    struct match_state state;
582    state.inexact_match = false;
583    state.has_exact_alu = false;
584    state.variables_seen = 0;
585 
586    if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
587                          swizzle, &state))
588       return NULL;
589 
590    void *bitsize_ctx = ralloc_context(NULL);
591    bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
592    bitsize_tree_filter_up(tree);
593    bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
594 
595    /* Inserting a mov may be unnecessary.  However, it's much easier to
596     * simply let copy propagation clean this up than to try to go through
597     * and rewrite swizzles ourselves.
598     */
599    nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
600    mov->dest.write_mask = instr->dest.write_mask;
601    nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
602                      instr->dest.dest.ssa.num_components,
603                      instr->dest.dest.ssa.bit_size, NULL);
604 
605    mov->src[0] = construct_value(replace,
606                                  instr->dest.dest.ssa.num_components, tree,
607                                  &state, &instr->instr, mem_ctx);
608    nir_instr_insert_before(&instr->instr, &mov->instr);
609 
610    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
611                             nir_src_for_ssa(&mov->dest.dest.ssa));
612 
613    /* We know this one has no more uses because we just rewrote them all,
614     * so we can remove it.  The rest of the matched expression, however, we
615     * don't know so much about.  We'll just let dead code clean them up.
616     */
617    nir_instr_remove(&instr->instr);
618 
619    ralloc_free(bitsize_ctx);
620 
621    return mov;
622 }
623